vendor: Update dependencies

This commit is contained in:
Matthew Holt 2017-07-27 16:11:56 -06:00
parent 74940af624
commit a48e4ecb5a
No known key found for this signature in database
GPG Key ID: 2A349DD577D586A5
185 changed files with 24095 additions and 13722 deletions

View File

@ -41,71 +41,71 @@ DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
#define ROTL(n, t, v) \ #define ROTL(n, t, v) \
VPSLLD $n, v, t; \ VPSLLD $n, v, t; \
VPSRLD $(32-n), v, v; \ VPSRLD $(32-n), v, v; \
VPXOR v, t, v VPXOR v, t, v
#define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \ #define CHACHA_QROUND(v0, v1, v2, v3, t, c16, c8) \
VPADDD v0, v1, v0; \ VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \ VPXOR v3, v0, v3; \
VPSHUFB c16, v3, v3; \ VPSHUFB c16, v3, v3; \
VPADDD v2, v3, v2; \ VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \ VPXOR v1, v2, v1; \
ROTL(12, t, v1); \ ROTL(12, t, v1); \
VPADDD v0, v1, v0; \ VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \ VPXOR v3, v0, v3; \
VPSHUFB c8, v3, v3; \ VPSHUFB c8, v3, v3; \
VPADDD v2, v3, v2; \ VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \ VPXOR v1, v2, v1; \
ROTL(7, t, v1) ROTL(7, t, v1)
#define CHACHA_SHUFFLE(v1, v2, v3) \ #define CHACHA_SHUFFLE(v1, v2, v3) \
VPSHUFD $0x39, v1, v1; \ VPSHUFD $0x39, v1, v1; \
VPSHUFD $0x4E, v2, v2; \ VPSHUFD $0x4E, v2, v2; \
VPSHUFD $-109, v3, v3 VPSHUFD $-109, v3, v3
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ #define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \ VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \ VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \ VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \ VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \ VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \ VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \ VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \ VMOVDQU t0, (32+off)(dst); \
VMOVDQU (64+off)(src), t0; \ VMOVDQU (64+off)(src), t0; \
VPERM2I128 $49, v1, v0, t1; \ VPERM2I128 $49, v1, v0, t1; \
VPXOR t0, t1, t0; \ VPXOR t0, t1, t0; \
VMOVDQU t0, (64+off)(dst); \ VMOVDQU t0, (64+off)(dst); \
VMOVDQU (96+off)(src), t0; \ VMOVDQU (96+off)(src), t0; \
VPERM2I128 $49, v3, v2, t1; \ VPERM2I128 $49, v3, v2, t1; \
VPXOR t0, t1, t0; \ VPXOR t0, t1, t0; \
VMOVDQU t0, (96+off)(dst) VMOVDQU t0, (96+off)(dst)
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \ #define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \ VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \ VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \ VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \ VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \ VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \ VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \ VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \ VMOVDQU t0, (32+off)(dst); \
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \ #define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
VPERM2I128 $49, v1, v0, t0; \ VPERM2I128 $49, v1, v0, t0; \
VMOVDQU t0, 0(dst); \ VMOVDQU t0, 0(dst); \
VPERM2I128 $49, v3, v2, t0; \ VPERM2I128 $49, v3, v2, t0; \
VMOVDQU t0, 32(dst) VMOVDQU t0, 32(dst)
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \ #define XOR_AVX(dst, src, off, v0, v1, v2, v3, t0) \
VPXOR 0+off(src), v0, t0; \ VPXOR 0+off(src), v0, t0; \
VMOVDQU t0, 0+off(dst); \ VMOVDQU t0, 0+off(dst); \
VPXOR 16+off(src), v1, t0; \ VPXOR 16+off(src), v1, t0; \
VMOVDQU t0, 16+off(dst); \ VMOVDQU t0, 16+off(dst); \
VPXOR 32+off(src), v2, t0; \ VPXOR 32+off(src), v2, t0; \
VMOVDQU t0, 32+off(dst); \ VMOVDQU t0, 32+off(dst); \
VPXOR 48+off(src), v3, t0; \ VPXOR 48+off(src), v3, t0; \
VMOVDQU t0, 48+off(dst) VMOVDQU t0, 48+off(dst)
#define TWO 0(SP) #define TWO 0(SP)
@ -119,417 +119,424 @@ GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32
#define TMP_1 256(SP) #define TMP_1 256(SP)
// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamAVX2(SB),4,$320-80 TEXT ·xorKeyStreamAVX2(SB), 4, $320-80
MOVQ dst_base+0(FP), DI MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), CX MOVQ src_len+32(FP), CX
MOVQ block+48(FP), BX MOVQ block+48(FP), BX
MOVQ state+56(FP), AX MOVQ state+56(FP), AX
MOVQ rounds+64(FP), DX MOVQ rounds+64(FP), DX
MOVQ SP, R8 MOVQ SP, R8
ADDQ $32, SP ADDQ $32, SP
ANDQ $-32, SP ANDQ $-32, SP
VMOVDQU 0(AX), Y2 VMOVDQU 0(AX), Y2
VMOVDQU 32(AX), Y3 VMOVDQU 32(AX), Y3
VPERM2I128 $0x22, Y2, Y0, Y0 VPERM2I128 $0x22, Y2, Y0, Y0
VPERM2I128 $0x33, Y2, Y1, Y1 VPERM2I128 $0x33, Y2, Y1, Y1
VPERM2I128 $0x22, Y3, Y2, Y2 VPERM2I128 $0x22, Y3, Y2, Y2
VPERM2I128 $0x33, Y3, Y3, Y3 VPERM2I128 $0x33, Y3, Y3, Y3
TESTQ CX, CX TESTQ CX, CX
JZ done JZ done
VMOVDQU ·one_AVX2<>(SB), Y4 VMOVDQU ·one_AVX2<>(SB), Y4
VPADDD Y4, Y3, Y3 VPADDD Y4, Y3, Y3
VMOVDQA Y0, STATE_0 VMOVDQA Y0, STATE_0
VMOVDQA Y1, STATE_1 VMOVDQA Y1, STATE_1
VMOVDQA Y2, STATE_2 VMOVDQA Y2, STATE_2
VMOVDQA Y3, STATE_3 VMOVDQA Y3, STATE_3
VMOVDQU ·rol16_AVX2<>(SB), Y4 VMOVDQU ·rol16_AVX2<>(SB), Y4
VMOVDQU ·rol8_AVX2<>(SB), Y5 VMOVDQU ·rol8_AVX2<>(SB), Y5
VMOVDQU ·two_AVX2<>(SB), Y6 VMOVDQU ·two_AVX2<>(SB), Y6
VMOVDQA Y4, Y14 VMOVDQA Y4, Y14
VMOVDQA Y5, Y15 VMOVDQA Y5, Y15
VMOVDQA Y4, C16 VMOVDQA Y4, C16
VMOVDQA Y5, C8 VMOVDQA Y5, C8
VMOVDQA Y6, TWO VMOVDQA Y6, TWO
CMPQ CX, $64
JBE between_0_and_64
CMPQ CX, $192
JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
CMPQ CX, $448
JBE between_320_and_448
CMPQ CX, $64
JBE between_0_and_64
CMPQ CX, $192
JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
CMPQ CX, $448
JBE between_320_and_448
at_least_512: at_least_512:
VMOVDQA Y0, Y4 VMOVDQA Y0, Y4
VMOVDQA Y1, Y5 VMOVDQA Y1, Y5
VMOVDQA Y2, Y6 VMOVDQA Y2, Y6
VPADDQ TWO, Y3, Y7 VPADDQ TWO, Y3, Y7
VMOVDQA Y0, Y8 VMOVDQA Y0, Y8
VMOVDQA Y1, Y9 VMOVDQA Y1, Y9
VMOVDQA Y2, Y10 VMOVDQA Y2, Y10
VPADDQ TWO, Y7, Y11 VPADDQ TWO, Y7, Y11
VMOVDQA Y0, Y12 VMOVDQA Y0, Y12
VMOVDQA Y1, Y13 VMOVDQA Y1, Y13
VMOVDQA Y2, Y14 VMOVDQA Y2, Y14
VPADDQ TWO, Y11, Y15 VPADDQ TWO, Y11, Y15
MOVQ DX, R9
MOVQ DX, R9
chacha_loop_512: chacha_loop_512:
VMOVDQA Y8, TMP_0 VMOVDQA Y8, TMP_0
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8) CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
VMOVDQA TMP_0, Y8 VMOVDQA TMP_0, Y8
VMOVDQA Y0, TMP_0 VMOVDQA Y0, TMP_0
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8) CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8) CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_SHUFFLE(Y1, Y2, Y3) CHACHA_SHUFFLE(Y1, Y2, Y3)
CHACHA_SHUFFLE(Y5, Y6, Y7) CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_SHUFFLE(Y9, Y10, Y11) CHACHA_SHUFFLE(Y9, Y10, Y11)
CHACHA_SHUFFLE(Y13, Y14, Y15) CHACHA_SHUFFLE(Y13, Y14, Y15)
CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
VMOVDQA TMP_0, Y0
VMOVDQA Y8, TMP_0
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
VMOVDQA TMP_0, Y8
CHACHA_SHUFFLE(Y3, Y2, Y1)
CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_SHUFFLE(Y11, Y10, Y9)
CHACHA_SHUFFLE(Y15, Y14, Y13)
SUBQ $2, R9
JA chacha_loop_512
VMOVDQA Y12, TMP_0 CHACHA_QROUND(Y12, Y13, Y14, Y15, Y0, C16, C8)
VMOVDQA Y13, TMP_1 CHACHA_QROUND(Y8, Y9, Y10, Y11, Y0, C16, C8)
VPADDD STATE_0, Y0, Y0 VMOVDQA TMP_0, Y0
VPADDD STATE_1, Y1, Y1 VMOVDQA Y8, TMP_0
VPADDD STATE_2, Y2, Y2 CHACHA_QROUND(Y4, Y5, Y6, Y7, Y8, C16, C8)
VPADDD STATE_3, Y3, Y3 CHACHA_QROUND(Y0, Y1, Y2, Y3, Y8, C16, C8)
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13) VMOVDQA TMP_0, Y8
VMOVDQA STATE_0, Y0 CHACHA_SHUFFLE(Y3, Y2, Y1)
VMOVDQA STATE_1, Y1 CHACHA_SHUFFLE(Y7, Y6, Y5)
VMOVDQA STATE_2, Y2 CHACHA_SHUFFLE(Y11, Y10, Y9)
VMOVDQA STATE_3, Y3 CHACHA_SHUFFLE(Y15, Y14, Y13)
VPADDQ TWO, Y3, Y3 SUBQ $2, R9
JA chacha_loop_512
VPADDD Y0, Y4, Y4 VMOVDQA Y12, TMP_0
VPADDD Y1, Y5, Y5 VMOVDQA Y13, TMP_1
VPADDD Y2, Y6, Y6 VPADDD STATE_0, Y0, Y0
VPADDD Y3, Y7, Y7 VPADDD STATE_1, Y1, Y1
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13) VPADDD STATE_2, Y2, Y2
VPADDQ TWO, Y3, Y3 VPADDD STATE_3, Y3, Y3
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13)
VMOVDQA STATE_0, Y0
VMOVDQA STATE_1, Y1
VMOVDQA STATE_2, Y2
VMOVDQA STATE_3, Y3
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y8, Y8 VPADDD Y0, Y4, Y4
VPADDD Y1, Y9, Y9 VPADDD Y1, Y5, Y5
VPADDD Y2, Y10, Y10 VPADDD Y2, Y6, Y6
VPADDD Y3, Y11, Y11 VPADDD Y3, Y7, Y7
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
VPADDD TMP_0, Y0, Y12 VPADDD Y0, Y8, Y8
VPADDD TMP_1, Y1, Y13 VPADDD Y1, Y9, Y9
VPADDD Y2, Y14, Y14 VPADDD Y2, Y10, Y10
VPADDD Y3, Y15, Y15 VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3 XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
VPADDQ TWO, Y3, Y3
CMPQ CX, $512 VPADDD TMP_0, Y0, Y12
JB less_than_512 VPADDD TMP_1, Y1, Y13
VPADDD Y2, Y14, Y14
VPADDD Y3, Y15, Y15
VPADDQ TWO, Y3, Y3
XOR_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5) CMPQ CX, $512
VMOVDQA Y3, STATE_3 JB less_than_512
ADDQ $512, SI
ADDQ $512, DI
SUBQ $512, CX
CMPQ CX, $448
JA at_least_512
TESTQ CX, CX XOR_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5)
JZ done VMOVDQA Y3, STATE_3
ADDQ $512, SI
ADDQ $512, DI
SUBQ $512, CX
CMPQ CX, $448
JA at_least_512
VMOVDQA C16, Y14 TESTQ CX, CX
VMOVDQA C8, Y15 JZ done
CMPQ CX, $64 VMOVDQA C16, Y14
JBE between_0_and_64 VMOVDQA C8, Y15
CMPQ CX, $192
JBE between_64_and_192 CMPQ CX, $64
CMPQ CX, $320 JBE between_0_and_64
JBE between_192_and_320 CMPQ CX, $192
JMP between_320_and_448 JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
JMP between_320_and_448
less_than_512: less_than_512:
XOR_UPPER_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5) XOR_UPPER_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5)
EXTRACT_LOWER(BX, Y12, Y13, Y14, Y15, Y4) EXTRACT_LOWER(BX, Y12, Y13, Y14, Y15, Y4)
ADDQ $448, SI ADDQ $448, SI
ADDQ $448, DI ADDQ $448, DI
SUBQ $448, CX SUBQ $448, CX
JMP finalize JMP finalize
between_320_and_448: between_320_and_448:
VMOVDQA Y0, Y4 VMOVDQA Y0, Y4
VMOVDQA Y1, Y5 VMOVDQA Y1, Y5
VMOVDQA Y2, Y6 VMOVDQA Y2, Y6
VPADDQ TWO, Y3, Y7 VPADDQ TWO, Y3, Y7
VMOVDQA Y0, Y8 VMOVDQA Y0, Y8
VMOVDQA Y1, Y9 VMOVDQA Y1, Y9
VMOVDQA Y2, Y10 VMOVDQA Y2, Y10
VPADDQ TWO, Y7, Y11 VPADDQ TWO, Y7, Y11
MOVQ DX, R9
MOVQ DX, R9
chacha_loop_384: chacha_loop_384:
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y1, Y2, Y3) CHACHA_SHUFFLE(Y1, Y2, Y3)
CHACHA_SHUFFLE(Y5, Y6, Y7) CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_SHUFFLE(Y9, Y10, Y11) CHACHA_SHUFFLE(Y9, Y10, Y11)
CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15) CHACHA_QROUND(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y3, Y2, Y1) CHACHA_SHUFFLE(Y3, Y2, Y1)
CHACHA_SHUFFLE(Y7, Y6, Y5) CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_SHUFFLE(Y11, Y10, Y9) CHACHA_SHUFFLE(Y11, Y10, Y9)
SUBQ $2, R9 SUBQ $2, R9
JA chacha_loop_384 JA chacha_loop_384
VPADDD STATE_0, Y0, Y0 VPADDD STATE_0, Y0, Y0
VPADDD STATE_1, Y1, Y1 VPADDD STATE_1, Y1, Y1
VPADDD STATE_2, Y2, Y2 VPADDD STATE_2, Y2, Y2
VPADDD STATE_3, Y3, Y3 VPADDD STATE_3, Y3, Y3
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13) XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13)
VMOVDQA STATE_0, Y0 VMOVDQA STATE_0, Y0
VMOVDQA STATE_1, Y1 VMOVDQA STATE_1, Y1
VMOVDQA STATE_2, Y2 VMOVDQA STATE_2, Y2
VMOVDQA STATE_3, Y3 VMOVDQA STATE_3, Y3
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
VPADDD Y0, Y4, Y4 VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5 VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6 VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7 VPADDD Y3, Y7, Y7
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13) XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
VPADDD Y0, Y8, Y8 VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9 VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10 VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11 VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
CMPQ CX, $384 CMPQ CX, $384
JB less_than_384 JB less_than_384
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
SUBQ $384, CX
TESTQ CX, CX
JE done
ADDQ $384, SI XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
ADDQ $384, DI SUBQ $384, CX
JMP between_0_and_64 TESTQ CX, CX
JE done
ADDQ $384, SI
ADDQ $384, DI
JMP between_0_and_64
less_than_384: less_than_384:
XOR_UPPER_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13) XOR_UPPER_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12) EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12)
ADDQ $320, SI ADDQ $320, SI
ADDQ $320, DI ADDQ $320, DI
SUBQ $320, CX SUBQ $320, CX
JMP finalize JMP finalize
between_192_and_320: between_192_and_320:
VMOVDQA Y0, Y4 VMOVDQA Y0, Y4
VMOVDQA Y1, Y5 VMOVDQA Y1, Y5
VMOVDQA Y2, Y6 VMOVDQA Y2, Y6
VMOVDQA Y3, Y7 VMOVDQA Y3, Y7
VMOVDQA Y0, Y8 VMOVDQA Y0, Y8
VMOVDQA Y1, Y9 VMOVDQA Y1, Y9
VMOVDQA Y2, Y10 VMOVDQA Y2, Y10
VPADDQ TWO, Y3, Y11 VPADDQ TWO, Y3, Y11
MOVQ DX, R9 MOVQ DX, R9
chacha_loop_256: chacha_loop_256:
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y5, Y6, Y7) CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_SHUFFLE(Y9, Y10, Y11) CHACHA_SHUFFLE(Y9, Y10, Y11)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15) CHACHA_QROUND(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y7, Y6, Y5) CHACHA_SHUFFLE(Y7, Y6, Y5)
CHACHA_SHUFFLE(Y11, Y10, Y9) CHACHA_SHUFFLE(Y11, Y10, Y9)
SUBQ $2, R9 SUBQ $2, R9
JA chacha_loop_256 JA chacha_loop_256
VPADDD Y0, Y4, Y4 VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5 VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6 VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7 VPADDD Y3, Y7, Y7
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDD Y0, Y8, Y8 VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9 VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10 VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11 VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
CMPQ CX, $256 CMPQ CX, $256
JB less_than_256 JB less_than_256
XOR_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13) XOR_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13)
SUBQ $256, CX SUBQ $256, CX
TESTQ CX, CX TESTQ CX, CX
JE done JE done
ADDQ $256, SI
ADDQ $256, DI
JMP between_0_and_64
ADDQ $256, SI
ADDQ $256, DI
JMP between_0_and_64
less_than_256: less_than_256:
XOR_UPPER_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13) XOR_UPPER_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13)
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12) EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12)
ADDQ $192, SI ADDQ $192, SI
ADDQ $192, DI ADDQ $192, DI
SUBQ $192, CX SUBQ $192, CX
JMP finalize JMP finalize
between_64_and_192: between_64_and_192:
VMOVDQA Y0, Y4 VMOVDQA Y0, Y4
VMOVDQA Y1, Y5 VMOVDQA Y1, Y5
VMOVDQA Y2, Y6 VMOVDQA Y2, Y6
VMOVDQA Y3, Y7 VMOVDQA Y3, Y7
MOVQ DX, R9
MOVQ DX, R9
chacha_loop_128: chacha_loop_128:
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y5, Y6, Y7) CHACHA_SHUFFLE(Y5, Y6, Y7)
CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15) CHACHA_QROUND(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE(Y7, Y6, Y5) CHACHA_SHUFFLE(Y7, Y6, Y5)
SUBQ $2, R9 SUBQ $2, R9
JA chacha_loop_128 JA chacha_loop_128
VPADDD Y0, Y4, Y4 VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5 VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6 VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7 VPADDD Y3, Y7, Y7
VPADDQ TWO, Y3, Y3 VPADDQ TWO, Y3, Y3
CMPQ CX, $128 CMPQ CX, $128
JB less_than_128 JB less_than_128
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
SUBQ $128, CX SUBQ $128, CX
TESTQ CX, CX TESTQ CX, CX
JE done JE done
ADDQ $128, SI
ADDQ $128, DI
JMP between_0_and_64
ADDQ $128, SI
ADDQ $128, DI
JMP between_0_and_64
less_than_128: less_than_128:
XOR_UPPER_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13) XOR_UPPER_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
EXTRACT_LOWER(BX, Y4, Y5, Y6, Y7, Y13) EXTRACT_LOWER(BX, Y4, Y5, Y6, Y7, Y13)
ADDQ $64, SI ADDQ $64, SI
ADDQ $64, DI ADDQ $64, DI
SUBQ $64, CX SUBQ $64, CX
JMP finalize JMP finalize
between_0_and_64: between_0_and_64:
VMOVDQA X0, X4 VMOVDQA X0, X4
VMOVDQA X1, X5 VMOVDQA X1, X5
VMOVDQA X2, X6 VMOVDQA X2, X6
VMOVDQA X3, X7 VMOVDQA X3, X7
MOVQ DX, R9
MOVQ DX, R9
chacha_loop_64: chacha_loop_64:
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE(X5, X6, X7) CHACHA_SHUFFLE(X5, X6, X7)
CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15) CHACHA_QROUND(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE(X7, X6, X5) CHACHA_SHUFFLE(X7, X6, X5)
SUBQ $2, R9 SUBQ $2, R9
JA chacha_loop_64 JA chacha_loop_64
VPADDD X0, X4, X4 VPADDD X0, X4, X4
VPADDD X1, X5, X5 VPADDD X1, X5, X5
VPADDD X2, X6, X6 VPADDD X2, X6, X6
VPADDD X3, X7, X7 VPADDD X3, X7, X7
VMOVDQU ·one_AVX<>(SB), X0 VMOVDQU ·one_AVX<>(SB), X0
VPADDQ X0, X3, X3 VPADDQ X0, X3, X3
CMPQ CX, $64 CMPQ CX, $64
JB less_than_64 JB less_than_64
XOR_AVX(DI, SI, 0, X4, X5, X6, X7, X13) XOR_AVX(DI, SI, 0, X4, X5, X6, X7, X13)
SUBQ $64, CX SUBQ $64, CX
JMP done JMP done
less_than_64: less_than_64:
VMOVDQU X4, 0(BX) VMOVDQU X4, 0(BX)
VMOVDQU X5, 16(BX) VMOVDQU X5, 16(BX)
VMOVDQU X6, 32(BX) VMOVDQU X6, 32(BX)
VMOVDQU X7, 48(BX) VMOVDQU X7, 48(BX)
finalize: finalize:
XORQ R11, R11 XORQ R11, R11
XORQ R12, R12 XORQ R12, R12
MOVQ CX, BP MOVQ CX, BP
xor_loop: xor_loop:
MOVB 0(SI), R11 MOVB 0(SI), R11
MOVB 0(BX), R12 MOVB 0(BX), R12
XORQ R11, R12 XORQ R11, R12
MOVB R12, 0(DI) MOVB R12, 0(DI)
INCQ SI INCQ SI
INCQ BX INCQ BX
INCQ DI INCQ DI
DECQ BP DECQ BP
JA xor_loop JA xor_loop
done: done:
VMOVDQU X3, 48(AX) VMOVDQU X3, 48(AX)
VZEROUPPER VZEROUPPER
MOVQ R8, SP MOVQ R8, SP
MOVQ CX, ret+72(FP) MOVQ CX, ret+72(FP)
RET RET
// func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte) // func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20AVX(SB), 4, $0-24 TEXT ·hChaCha20AVX(SB), 4, $0-24
MOVQ out+0(FP), DI MOVQ out+0(FP), DI
MOVQ nonce+8(FP), AX MOVQ nonce+8(FP), AX
MOVQ key+16(FP), BX MOVQ key+16(FP), BX
VMOVDQU ·sigma_AVX<>(SB), X0 VMOVDQU ·sigma_AVX<>(SB), X0
VMOVDQU 0(BX), X1 VMOVDQU 0(BX), X1
VMOVDQU 16(BX), X2 VMOVDQU 16(BX), X2
VMOVDQU 0(AX), X3 VMOVDQU 0(AX), X3
VMOVDQU ·rol16_AVX2<>(SB), X5 VMOVDQU ·rol16_AVX2<>(SB), X5
VMOVDQU ·rol8_AVX2<>(SB), X6 VMOVDQU ·rol8_AVX2<>(SB), X6
MOVQ $20, CX
MOVQ $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)
CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X3, X2, X1) CHACHA_SHUFFLE(X3, X2, X1)
SUBQ $2, CX SUBQ $2, CX
JNZ chacha_loop JNZ chacha_loop
VMOVDQU X0, 0(DI) VMOVDQU X0, 0(DI)
VMOVDQU X3, 16(DI) VMOVDQU X3, 16(DI)
VZEROUPPER VZEROUPPER
RET RET
// func supportsAVX2() bool // func supportsAVX2() bool
TEXT ·supportsAVX2(SB), 4, $0-1 TEXT ·supportsAVX2(SB), 4, $0-1
MOVQ runtime·support_avx(SB), AX MOVQ runtime·support_avx(SB), AX
MOVQ runtime·support_avx2(SB), BX MOVQ runtime·support_avx2(SB), BX
ANDQ AX, BX ANDQ AX, BX
MOVB BX, ret+0(FP) MOVB BX, ret+0(FP)
RET RET

View File

@ -25,231 +25,233 @@ DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16
#define ROTL_SSE2(n, t, v) \ #define ROTL_SSE2(n, t, v) \
MOVO v, t; \ MOVO v, t; \
PSLLL $n, t; \ PSLLL $n, t; \
PSRLL $(32-n), v; \ PSRLL $(32-n), v; \
PXOR t, v PXOR t, v
#define CHACHA_QROUND_SSE2(v0 , v1 , v2 , v3 , t0) \ #define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t0) \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
ROTL_SSE2(16, t0, v3); \ ROTL_SSE2(16, t0, v3); \
PADDL v3, v2; \ PADDL v3, v2; \
PXOR v2, v1; \ PXOR v2, v1; \
ROTL_SSE2(12, t0, v1); \ ROTL_SSE2(12, t0, v1); \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
ROTL_SSE2(8, t0, v3); \ ROTL_SSE2(8, t0, v3); \
PADDL v3, v2; \ PADDL v3, v2; \
PXOR v2, v1; \ PXOR v2, v1; \
ROTL_SSE2(7, t0, v1) ROTL_SSE2(7, t0, v1)
#define CHACHA_QROUND_SSSE3(v0 , v1 , v2 , v3 , t0, r16, r8) \ #define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t0, r16, r8) \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
PSHUFB r16, v3; \ PSHUFB r16, v3; \
PADDL v3, v2; \ PADDL v3, v2; \
PXOR v2, v1; \ PXOR v2, v1; \
ROTL_SSE2(12, t0, v1); \ ROTL_SSE2(12, t0, v1); \
PADDL v1, v0; \ PADDL v1, v0; \
PXOR v0, v3; \ PXOR v0, v3; \
PSHUFB r8, v3; \ PSHUFB r8, v3; \
PADDL v3, v2; \ PADDL v3, v2; \
PXOR v2, v1; \ PXOR v2, v1; \
ROTL_SSE2(7, t0, v1) ROTL_SSE2(7, t0, v1)
#define CHACHA_SHUFFLE(v1, v2, v3) \ #define CHACHA_SHUFFLE(v1, v2, v3) \
PSHUFL $0x39, v1, v1; \ PSHUFL $0x39, v1, v1; \
PSHUFL $0x4E, v2, v2; \ PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3 PSHUFL $0x93, v3, v3
#define XOR(dst, src, off, v0 , v1 , v2 , v3 , t0) \ #define XOR(dst, src, off, v0, v1, v2, v3, t0) \
MOVOU 0+off(src), t0; \ MOVOU 0+off(src), t0; \
PXOR v0, t0; \ PXOR v0, t0; \
MOVOU t0, 0+off(dst); \ MOVOU t0, 0+off(dst); \
MOVOU 16+off(src), t0; \ MOVOU 16+off(src), t0; \
PXOR v1, t0; \ PXOR v1, t0; \
MOVOU t0, 16+off(dst); \ MOVOU t0, 16+off(dst); \
MOVOU 32+off(src), t0; \ MOVOU 32+off(src), t0; \
PXOR v2, t0; \ PXOR v2, t0; \
MOVOU t0, 32+off(dst); \ MOVOU t0, 32+off(dst); \
MOVOU 48+off(src), t0; \ MOVOU 48+off(src), t0; \
PXOR v3, t0; \ PXOR v3, t0; \
MOVOU t0, 48+off(dst) MOVOU t0, 48+off(dst)
#define FINALIZE(dst, src, block, len, t0, t1) \ #define FINALIZE(dst, src, block, len, t0, t1) \
XORL t0, t0; \ XORL t0, t0; \
XORL t1, t1; \ XORL t1, t1; \
finalize: \ finalize: \
MOVB 0(src), t0; \ MOVB 0(src), t0; \
MOVB 0(block), t1; \ MOVB 0(block), t1; \
XORL t0, t1; \ XORL t0, t1; \
MOVB t1, 0(dst); \ MOVB t1, 0(dst); \
INCL src; \ INCL src; \
INCL block; \ INCL block; \
INCL dst; \ INCL dst; \
DECL len; \ DECL len; \
JA finalize \ JA finalize \
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSE2(SB),4,$0-40 TEXT ·xorKeyStreamSSE2(SB), 4, $0-40
MOVL dst_base+0(FP), DI MOVL dst_base+0(FP), DI
MOVL src_base+12(FP), SI MOVL src_base+12(FP), SI
MOVL src_len+16(FP), CX MOVL src_len+16(FP), CX
MOVL state+28(FP), AX MOVL state+28(FP), AX
MOVL rounds+32(FP), DX MOVL rounds+32(FP), DX
MOVOU 0(AX), X0 MOVOU 0(AX), X0
MOVOU 16(AX), X1 MOVOU 16(AX), X1
MOVOU 32(AX), X2 MOVOU 32(AX), X2
MOVOU 48(AX), X3 MOVOU 48(AX), X3
TESTL CX, CX TESTL CX, CX
JZ done JZ done
at_least_64: at_least_64:
MOVO X0, X4 MOVO X0, X4
MOVO X1, X5 MOVO X1, X5
MOVO X2, X6 MOVO X2, X6
MOVO X3, X7 MOVO X3, X7
MOVL DX, BX
MOVL DX, BX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE(X5, X6, X7) CHACHA_SHUFFLE(X5, X6, X7)
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0) CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE(X7, X6, X5) CHACHA_SHUFFLE(X7, X6, X5)
SUBL $2, BX SUBL $2, BX
JA chacha_loop JA chacha_loop
MOVOU 0(AX), X0 MOVOU 0(AX), X0
PADDL X0, X4 PADDL X0, X4
PADDL X1, X5 PADDL X1, X5
PADDL X2, X6 PADDL X2, X6
PADDL X3, X7 PADDL X3, X7
MOVOU ·one<>(SB), X0 MOVOU ·one<>(SB), X0
PADDQ X0, X3 PADDQ X0, X3
CMPL CX, $64 CMPL CX, $64
JB less_than_64 JB less_than_64
XOR(DI, SI, 0, X4, X5, X6, X7, X0) XOR(DI, SI, 0, X4, X5, X6, X7, X0)
MOVOU 0(AX), X0 MOVOU 0(AX), X0
ADDL $64, SI ADDL $64, SI
ADDL $64, DI ADDL $64, DI
SUBL $64, CX SUBL $64, CX
JNZ at_least_64 JNZ at_least_64
less_than_64: less_than_64:
MOVL CX, BP MOVL CX, BP
TESTL BP, BP TESTL BP, BP
JZ done JZ done
MOVL block+24(FP), BX MOVL block+24(FP), BX
MOVOU X4, 0(BX) MOVOU X4, 0(BX)
MOVOU X5, 16(BX) MOVOU X5, 16(BX)
MOVOU X6, 32(BX) MOVOU X6, 32(BX)
MOVOU X7, 48(BX) MOVOU X7, 48(BX)
FINALIZE(DI, SI, BX, BP, AX, DX) FINALIZE(DI, SI, BX, BP, AX, DX)
done: done:
MOVL state+28(FP), AX MOVL state+28(FP), AX
MOVOU X3, 48(AX) MOVOU X3, 48(AX)
MOVL CX, ret+36(FP) MOVL CX, ret+36(FP)
RET RET
// func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int // func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSSE3(SB),4,$64-40 TEXT ·xorKeyStreamSSSE3(SB), 4, $64-40
MOVL dst_base+0(FP), DI MOVL dst_base+0(FP), DI
MOVL src_base+12(FP), SI MOVL src_base+12(FP), SI
MOVL src_len+16(FP), CX MOVL src_len+16(FP), CX
MOVL state+28(FP), AX MOVL state+28(FP), AX
MOVL rounds+32(FP), DX MOVL rounds+32(FP), DX
MOVOU 48(AX), X3 MOVOU 48(AX), X3
TESTL CX, CX TESTL CX, CX
JZ done JZ done
MOVL SP, BP MOVL SP, BP
ADDL $16, SP ADDL $16, SP
ANDL $-16, SP ANDL $-16, SP
MOVOU ·one<>(SB), X0 MOVOU ·one<>(SB), X0
MOVOU 16(AX), X1 MOVOU 16(AX), X1
MOVOU 32(AX), X2 MOVOU 32(AX), X2
MOVO X0, 0(SP) MOVO X0, 0(SP)
MOVO X1, 16(SP) MOVO X1, 16(SP)
MOVO X2, 32(SP) MOVO X2, 32(SP)
MOVOU 0(AX), X0
MOVOU ·rol16<>(SB), X1
MOVOU ·rol8<>(SB), X2
MOVOU 0(AX), X0
MOVOU ·rol16<>(SB), X1
MOVOU ·rol8<>(SB), X2
at_least_64: at_least_64:
MOVO X0, X4 MOVO X0, X4
MOVO 16(SP), X5 MOVO 16(SP), X5
MOVO 32(SP), X6 MOVO 32(SP), X6
MOVO X3, X7 MOVO X3, X7
MOVL DX, BX
MOVL DX, BX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
CHACHA_SHUFFLE(X5, X6, X7) CHACHA_SHUFFLE(X5, X6, X7)
CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2) CHACHA_QROUND_SSSE3(X4, X5, X6, X7, X0, X1, X2)
CHACHA_SHUFFLE(X7, X6, X5) CHACHA_SHUFFLE(X7, X6, X5)
SUBL $2, BX SUBL $2, BX
JA chacha_loop JA chacha_loop
MOVOU 0(AX), X0 MOVOU 0(AX), X0
PADDL X0, X4 PADDL X0, X4
PADDL 16(SP), X5 PADDL 16(SP), X5
PADDL 32(SP), X6 PADDL 32(SP), X6
PADDL X3, X7 PADDL X3, X7
PADDQ 0(SP), X3 PADDQ 0(SP), X3
CMPL CX, $64 CMPL CX, $64
JB less_than_64 JB less_than_64
XOR(DI, SI, 0, X4, X5, X6, X7, X0) XOR(DI, SI, 0, X4, X5, X6, X7, X0)
MOVOU 0(AX), X0 MOVOU 0(AX), X0
ADDL $64, SI ADDL $64, SI
ADDL $64, DI ADDL $64, DI
SUBL $64, CX SUBL $64, CX
JNZ at_least_64 JNZ at_least_64
less_than_64: less_than_64:
MOVL BP, SP MOVL BP, SP
MOVL CX, BP MOVL CX, BP
TESTL BP, BP TESTL BP, BP
JE done JE done
MOVL block+24(FP), BX MOVL block+24(FP), BX
MOVOU X4, 0(BX) MOVOU X4, 0(BX)
MOVOU X5, 16(BX) MOVOU X5, 16(BX)
MOVOU X6, 32(BX) MOVOU X6, 32(BX)
MOVOU X7, 48(BX) MOVOU X7, 48(BX)
FINALIZE(DI, SI, BX, BP, AX, DX) FINALIZE(DI, SI, BX, BP, AX, DX)
done: done:
MOVL state+28(FP), AX MOVL state+28(FP), AX
MOVOU X3, 48(AX) MOVOU X3, 48(AX)
MOVL CX, ret+36(FP) MOVL CX, ret+36(FP)
RET RET
// func supportsSSE2() bool // func supportsSSE2() bool
TEXT ·supportsSSE2(SB), NOSPLIT, $0-1 TEXT ·supportsSSE2(SB), NOSPLIT, $0-1
XORL AX, AX XORL AX, AX
INCL AX INCL AX
CPUID CPUID
SHRL $26, DX SHRL $26, DX
ANDL $1, DX ANDL $1, DX
MOVB DX, ret+0(FP) MOVB DX, ret+0(FP)
RET RET
// func supportsSSSE3() bool // func supportsSSSE3() bool
TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1 TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1
XORL AX, AX XORL AX, AX
INCL AX INCL AX
CPUID CPUID
SHRL $9, CX SHRL $9, CX
ANDL $1, CX ANDL $1, CX
@ -258,50 +260,52 @@ TEXT ·supportsSSSE3(SB), NOSPLIT, $0-1
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte) // func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSE2(SB), 4, $0-12 TEXT ·hChaCha20SSE2(SB), 4, $0-12
MOVL out+0(FP), DI MOVL out+0(FP), DI
MOVL nonce+4(FP), AX MOVL nonce+4(FP), AX
MOVL key+8(FP), BX MOVL key+8(FP), BX
MOVOU ·sigma<>(SB), X0 MOVOU ·sigma<>(SB), X0
MOVOU 0(BX), X1 MOVOU 0(BX), X1
MOVOU 16(BX), X2 MOVOU 16(BX), X2
MOVOU 0(AX), X3 MOVOU 0(AX), X3
MOVL $20, CX
MOVL $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4) CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE(X3, X2, X1) CHACHA_SHUFFLE(X3, X2, X1)
SUBL $2, CX SUBL $2, CX
JNZ chacha_loop JNZ chacha_loop
MOVOU X0, 0(DI) MOVOU X0, 0(DI)
MOVOU X3, 16(DI) MOVOU X3, 16(DI)
RET RET
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte) // func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSSE3(SB), 4, $0-12 TEXT ·hChaCha20SSSE3(SB), 4, $0-12
MOVL out+0(FP), DI MOVL out+0(FP), DI
MOVL nonce+4(FP), AX MOVL nonce+4(FP), AX
MOVL key+8(FP), BX MOVL key+8(FP), BX
MOVOU ·sigma<>(SB), X0 MOVOU ·sigma<>(SB), X0
MOVOU 0(BX), X1 MOVOU 0(BX), X1
MOVOU 16(BX), X2 MOVOU 16(BX), X2
MOVOU 0(AX), X3 MOVOU 0(AX), X3
MOVOU ·rol16<>(SB), X5 MOVOU ·rol16<>(SB), X5
MOVOU ·rol8<>(SB), X6 MOVOU ·rol8<>(SB), X6
MOVL $20, CX
MOVL $20, CX
chacha_loop: chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X1, X2, X3) CHACHA_SHUFFLE(X1, X2, X3)
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6) CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE(X3, X2, X1) CHACHA_SHUFFLE(X3, X2, X1)
SUBL $2, CX SUBL $2, CX
JNZ chacha_loop JNZ chacha_loop
MOVOU X0, 0(DI) MOVOU X0, 0(DI)
MOVOU X3, 16(DI) MOVOU X3, 16(DI)
RET RET

File diff suppressed because it is too large Load Diff

View File

@ -174,11 +174,11 @@ func sizeFixed32(x uint64) int {
// This is the format used for the sint64 protocol buffer type. // This is the format used for the sint64 protocol buffer type.
func (p *Buffer) EncodeZigzag64(x uint64) error { func (p *Buffer) EncodeZigzag64(x uint64) error {
// use signed number to get arithmetic right shift. // use signed number to get arithmetic right shift.
return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) return p.EncodeVarint((x << 1) ^ uint64((int64(x) >> 63)))
} }
func sizeZigzag64(x uint64) int { func sizeZigzag64(x uint64) int {
return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) return sizeVarint((x << 1) ^ uint64((int64(x) >> 63)))
} }
// EncodeZigzag32 writes a zigzag-encoded 32-bit integer // EncodeZigzag32 writes a zigzag-encoded 32-bit integer

View File

@ -865,7 +865,7 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) error {
return p.readStruct(fv, terminator) return p.readStruct(fv, terminator)
case reflect.Uint32: case reflect.Uint32:
if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil {
fv.SetUint(uint64(x)) fv.SetUint(x)
return nil return nil
} }
case reflect.Uint64: case reflect.Uint64:

View File

@ -6,9 +6,8 @@
// //
// Overview // Overview
// //
// The Conn type represents a WebSocket connection. A server application uses // The Conn type represents a WebSocket connection. A server application calls
// the Upgrade function from an Upgrader object with a HTTP request handler // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
// to get a pointer to a Conn:
// //
// var upgrader = websocket.Upgrader{ // var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024, // ReadBufferSize: 1024,
@ -33,7 +32,7 @@
// if err != nil { // if err != nil {
// return // return
// } // }
// if err = conn.WriteMessage(messageType, p); err != nil { // if err := conn.WriteMessage(messageType, p); err != nil {
// return err // return err
// } // }
// } // }
@ -147,9 +146,9 @@
// CheckOrigin: func(r *http.Request) bool { return true }, // CheckOrigin: func(r *http.Request) bool { return true },
// } // }
// //
// The deprecated Upgrade function does not enforce an origin policy. It's the // The deprecated package-level Upgrade function does not perform origin
// application's responsibility to check the Origin header before calling // checking. The application is responsible for checking the Origin header
// Upgrade. // before calling the Upgrade function.
// //
// Compression EXPERIMENTAL // Compression EXPERIMENTAL
// //

View File

@ -129,6 +129,9 @@ func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
} }
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
client.hub.register <- client client.hub.register <- client
// Allow collection of memory referenced by the caller by doing all work in
// new goroutines.
go client.writePump() go client.writePump()
client.readPump() go client.readPump()
} }

View File

@ -9,12 +9,14 @@ import (
"io" "io"
) )
// WriteJSON is deprecated, use c.WriteJSON instead. // WriteJSON writes the JSON encoding of v as a message.
//
// Deprecated: Use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error { func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v) return c.WriteJSON(v)
} }
// WriteJSON writes the JSON encoding of v to the connection. // WriteJSON writes the JSON encoding of v as a message.
// //
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error {
return err2 return err2
} }
// ReadJSON is deprecated, use c.ReadJSON instead. // ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// Deprecated: Use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error { func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v) return c.ReadJSON(v)
} }

View File

@ -230,10 +230,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// This function is deprecated, use websocket.Upgrader instead. // Deprecated: Use websocket.Upgrader instead.
// //
// The application is responsible for checking the request origin before // Upgrade does not perform origin checking. The application is responsible for
// calling Upgrade. An example implementation of the same origin policy is: // checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
// //
// if req.Header.Get("Origin") != "http://"+req.Host { // if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403) // http.Error(w, "Origin not allowed", 403)

View File

@ -111,14 +111,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
case escape: case escape:
escape = false escape = false
p[j] = b p[j] = b
j += 1 j++
case b == '\\': case b == '\\':
escape = true escape = true
case b == '"': case b == '"':
return string(p[:j]), s[i+1:] return string(p[:j]), s[i+1:]
default: default:
p[j] = b p[j] = b
j += 1 j++
} }
} }
return "", "" return "", ""

View File

@ -9,6 +9,7 @@ import (
// SentPacketHandler handles ACKs received for outgoing packets // SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface { type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet) error SentPacket(packet *Packet) error
ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error ReceivedAck(ackFrame *frames.AckFrame, withPacketNumber protocol.PacketNumber, recvTime time.Time) error
@ -26,5 +27,6 @@ type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAlarmTimeout() time.Time
GetAckFrame() *frames.AckFrame GetAckFrame() *frames.AckFrame
} }

View File

@ -8,13 +8,6 @@ import (
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
) )
var (
// ErrDuplicatePacket occurres when a duplicate packet is received
ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet")
// ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored
ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting")
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number") var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct { type receivedPacketHandler struct {
@ -30,20 +23,14 @@ type receivedPacketHandler struct {
retransmittablePacketsReceivedSinceLastAck int retransmittablePacketsReceivedSinceLastAck int
ackQueued bool ackQueued bool
ackAlarm time.Time ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame lastAck *frames.AckFrame
} }
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler { func NewReceivedPacketHandler() ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback, ackSendDelay: protocol.AckSendDelay,
ackSendDelay: protocol.AckSendDelay,
} }
} }
@ -52,19 +39,10 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return errInvalidPacketNumber return errInvalidPacketNumber
} }
// if the packet number is smaller than the largest LeastUnacked value of a StopWaiting we received, we cannot detect if this packet has a duplicate number if packetNumber > h.ignorePacketsBelow {
// the packet has to be ignored anyway if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
if packetNumber <= h.ignorePacketsBelow { return err
return ErrPacketSmallerThanLastStopWaiting }
}
if h.packetHistory.IsDuplicate(packetNumber) {
return ErrDuplicatePacket
}
err := h.packetHistory.ReceivedPacket(packetNumber)
if err != nil {
return err
} }
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
@ -89,7 +67,6 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
} }
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
var ackAlarmSet bool
h.packetsReceivedSinceLastAck++ h.packetsReceivedSinceLastAck++
if shouldInstigateAck { if shouldInstigateAck {
@ -124,7 +101,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
} else { } else {
if h.ackAlarm.IsZero() { if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay) h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
} }
} }
} }
@ -132,11 +108,6 @@ func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber
if h.ackQueued { if h.ackQueued {
// cancel the ack alarm // cancel the ack alarm
h.ackAlarm = time.Time{} h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
} }
} }
@ -164,3 +135,5 @@ func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
return ack return ack
} }
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@ -2,9 +2,9 @@ package ackhandler
import ( import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type receivedPacketHistory struct { type receivedPacketHistory struct {

View File

@ -0,0 +1,38 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/frames"
)
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []frames.Frame) []frames.Frame {
res := make([]frames.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
}
}
return res
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f frames.Frame) bool {
switch f.(type) {
case *frames.StopWaitingFrame:
return false
case *frames.AckFrame:
return false
default:
return true
}
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []frames.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true
}
}
return false
}

View File

@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (
@ -106,26 +106,27 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error {
} }
} }
now := time.Now()
packet.SendTime = now
if packet.Length == 0 {
return errors.New("SentPacketHandler: packet cannot be empty")
}
h.bytesInFlight += packet.Length
h.lastSentPacketNumber = packet.PacketNumber h.lastSentPacketNumber = packet.PacketNumber
h.packetHistory.PushBack(*packet) now := time.Now()
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
packet.SendTime = now
h.bytesInFlight += packet.Length
h.packetHistory.PushBack(*packet)
}
h.congestion.OnPacketSent( h.congestion.OnPacketSent(
now, now,
h.bytesInFlight, h.bytesInFlight,
packet.PacketNumber, packet.PacketNumber,
packet.Length, packet.Length,
true, /* TODO: is retransmittable */ isRetransmittable,
) )
h.updateLossDetectionAlarm() h.updateLossDetectionAlarm()
return nil return nil
} }
@ -310,10 +311,11 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 { if len(h.retransmissionQueue) == 0 {
return nil return nil
} }
queueLen := len(h.retransmissionQueue) packet := h.retransmissionQueue[0]
// packets are usually NACKed in descending order. So use the slice as a stack // Shift the slice and don't retain anything that isn't needed.
packet := h.retransmissionQueue[queueLen-1] copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue = h.retransmissionQueue[:queueLen-1] h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet return packet
} }
@ -333,7 +335,11 @@ func (h *sentPacketHandler) SendingAllowed() bool {
h.bytesInFlight, h.bytesInFlight,
h.congestion.GetCongestionWindow()) h.congestion.GetCongestionWindow())
} }
return !(congestionLimited || maxTrackedLimited) // Workaround for #555:
// Always allow sending of retransmissions. This should probably be limited
// to RTOs, but we currently don't have a nice way of distinguishing them.
haveRetransmissions := len(h.retransmissionQueue) > 0
return !maxTrackedLimited && (!congestionLimited || haveRetransmissions)
} }
func (h *sentPacketHandler) retransmitOldestTwoPackets() { func (h *sentPacketHandler) retransmitOldestTwoPackets() {

View File

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -9,9 +10,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type client struct { type client struct {
@ -24,6 +25,7 @@ type client struct {
errorChan chan struct{} errorChan chan struct{}
handshakeChan <-chan handshakeEvent handshakeChan <-chan handshakeEvent
tlsConf *tls.Config
config *Config config *Config
versionNegotiated bool // has version negotiation completed yet versionNegotiated bool // has version negotiation completed yet
@ -39,7 +41,7 @@ var (
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) { func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Dial(udpConn, udpAddr, addr, config) return Dial(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialAddrNonFWSecure establishes a new QUIC connection to a server. // DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) { func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -62,20 +68,33 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return DialNonFWSecure(udpConn, udpAddr, addr, config) return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
} }
// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. // DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) { func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID() connID, err := utils.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostname, _, err := net.SplitHostPort(host) var hostname string
if err != nil { if tlsConf != nil {
return nil, err hostname = tlsConf.ServerName
}
if hostname == "" {
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
} }
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
@ -83,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID, connectionID: connID,
hostname: hostname, hostname: hostname,
tlsConf: tlsConf,
config: clientConfig, config: clientConfig,
version: clientConfig.Versions[0], version: clientConfig.Versions[0],
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
@ -93,15 +113,21 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
return nil, err return nil, err
} }
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", hostname, c.conn.RemoteAddr().String(), c.connectionID, c.version) utils.Infof("Starting new connection to %s (%s -> %s), connectionID %x, version %d", hostname, c.conn.LocalAddr().String(), c.conn.RemoteAddr().String(), c.connectionID, c.version)
return c.session.(NonFWSession), c.establishSecureConnection() return c.session.(NonFWSession), c.establishSecureConnection()
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI. // The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { func Dial(
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config) pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,16 +138,38 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return sess, nil return sess, nil
} }
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config { func populateClientConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
return &Config{ return &Config{
TLSConfig: config.TLSConfig, Versions: versions,
Versions: versions, HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
KeepAlive: config.KeepAlive,
} }
} }
@ -163,31 +211,46 @@ func (c *client) listen() {
} }
data = data[:n] data = data[:n]
err = c.handlePacket(addr, data) c.handlePacket(addr, data)
if err != nil {
utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err)
break
}
} }
} }
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now() rcvTime := time.Now()
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer) hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
if err != nil { if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) utils.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
// drop this packet if we can't parse the Public Header
return
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
if hdr.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID {
utils.Infof("Received a spoofed Public Reset. Ignoring.")
return
}
pr, err := parsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
return
}
utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber)
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber)))
return
}
// ignore delayed / duplicated version negotiation packets // ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag { if c.versionNegotiated && hdr.VersionFlag {
return nil return
} }
// this is the first packet after the client sent a packet with the VersionFlag set // this is the first packet after the client sent a packet with the VersionFlag set
@ -198,7 +261,10 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
if hdr.VersionFlag { if hdr.VersionFlag {
// version negotiation packets have no payload // version negotiation packets have no payload
return c.handlePacketWithVersionFlag(hdr) if err := c.handlePacketWithVersionFlag(hdr); err != nil {
c.session.Close(err)
}
return
} }
c.session.handlePacket(&receivedPacket{ c.session.handlePacket(&receivedPacket{
@ -207,7 +273,6 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
data: packet[len(packet)-r.Len():], data: packet[len(packet)-r.Len():],
rcvTime: rcvTime, rcvTime: rcvTime,
}) })
return nil
} }
func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error { func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
@ -246,6 +311,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.hostname, c.hostname,
c.version, c.version,
c.connectionID, c.connectionID,
c.tlsConf,
c.config, c.config,
negotiatedVersions, negotiatedVersions,
) )

View File

@ -4,8 +4,8 @@ import (
"math" "math"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// This cubic implementation is based on the one found in Chromiums's QUIC // This cubic implementation is based on the one found in Chromiums's QUIC

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
const ( const (

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// Note(pwestin): the magic clamping numbers come from the original code in // Note(pwestin): the magic clamping numbers come from the original code in

View File

@ -3,8 +3,8 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937 // PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937

View File

@ -3,7 +3,7 @@ package congestion
import ( import (
"time" "time"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
const ( const (

View File

@ -102,3 +102,12 @@ func (cc *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
// If nothing matches, return the first certificate. // If nothing matches, return the first certificate.
return &c.Certificates[0], nil return &c.Certificates[0], nil
} }
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type entryType uint8 type entryType uint8

View File

@ -107,15 +107,14 @@ func (c *certManager) Verify(hostname string) error {
var opts x509.VerifyOptions var opts x509.VerifyOptions
if c.config != nil { if c.config != nil {
opts.Roots = c.config.RootCAs opts.Roots = c.config.RootCAs
opts.DNSName = c.config.ServerName
if c.config.Time == nil { if c.config.Time == nil {
opts.CurrentTime = time.Now() opts.CurrentTime = time.Now()
} else { } else {
opts.CurrentTime = c.config.Time() opts.CurrentTime = c.config.Time()
} }
} else {
opts.DNSName = hostname
} }
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates // the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 { if len(c.chain) > 1 {

View File

@ -1,14 +0,0 @@
// +build go1.8
package crypto
import "crypto/tls"
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
return c.GetConfigForClient(&tls.ClientHelloInfo{
ServerName: sni,
})
}

View File

@ -1,9 +0,0 @@
// +build !go1.8
package crypto
import "crypto/tls"
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
return c, nil
}

View File

@ -5,8 +5,8 @@ import (
"crypto/sha256" "crypto/sha256"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
) )

View File

@ -8,7 +8,7 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
func main() { func main() {
@ -24,7 +24,7 @@ func main() {
utils.SetLogTimeFormat("") utils.SetLogTimeFormat("")
hclient := &http.Client{ hclient := &http.Client{
Transport: &h2quic.QuicRoundTripper{}, Transport: &h2quic.RoundTripper{},
} }
var wg sync.WaitGroup var wg sync.WaitGroup

View File

@ -31,10 +31,7 @@ func main() {
// Start a server that echos all data on the first stream opened by the client // Start a server that echos all data on the first stream opened by the client
func echoServer() error { func echoServer() error {
cfgServer := &quic.Config{ listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
TLSConfig: generateTLSConfig(),
}
listener, err := quic.ListenAddr(addr, cfgServer)
if err != nil { if err != nil {
return err return err
} }
@ -52,10 +49,7 @@ func echoServer() error {
} }
func clientMain() error { func clientMain() error {
cfgClient := &quic.Config{ session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
session, err := quic.DialAddr(addr, cfgClient)
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,7 +18,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"github.com/lucas-clemente/quic-go/h2quic" "github.com/lucas-clemente/quic-go/h2quic"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
type binds []string type binds []string

View File

@ -7,9 +7,9 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowControlManager struct { type flowControlManager struct {
@ -78,7 +78,7 @@ func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset
if streamFlowController.ContributesToConnection() { if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment) f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() { if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
} }
} }
@ -107,7 +107,7 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b
if streamFlowController.ContributesToConnection() { if streamFlowController.ContributesToConnection() {
f.connFlowController.IncrementHighestReceived(increment) f.connFlowController.IncrementHighestReceived(increment)
if f.connFlowController.CheckFlowControlViolation() { if f.connFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, f.connFlowController.receiveWindow)) return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", f.connFlowController.highestReceived, f.connFlowController.receiveWindow))
} }
} }
@ -157,6 +157,11 @@ func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (proto
f.mutex.RLock() f.mutex.RLock()
defer f.mutex.RUnlock() defer f.mutex.RUnlock()
// StreamID can be 0 when retransmitting
if streamID == 0 {
return f.connFlowController.receiveWindow, nil
}
flowController, err := f.getFlowController(streamID) flowController, err := f.getFlowController(streamID)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -6,8 +6,8 @@ import (
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowController struct { type flowController struct {

View File

@ -5,8 +5,8 @@ import (
"errors" "errors"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A BlockedFrame in QUIC // A BlockedFrame in QUIC

View File

@ -6,9 +6,9 @@ import (
"io" "io"
"math" "math"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A ConnectionCloseFrame in QUIC // A ConnectionCloseFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"io" "io"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A GoawayFrame is a GOAWAY frame // A GoawayFrame is a GOAWAY frame

View File

@ -1,6 +1,6 @@
package frames package frames
import "github.com/lucas-clemente/quic-go/utils" import "github.com/lucas-clemente/quic-go/internal/utils"
// LogFrame logs a frame, either sent or received // LogFrame logs a frame, either sent or received
func LogFrame(frame Frame, sent bool) { func LogFrame(frame Frame, sent bool) {

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A RstStreamFrame in QUIC // A RstStreamFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StopWaitingFrame in QUIC // A StopWaitingFrame in QUIC

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A StreamFrame of QUIC // A StreamFrame of QUIC

View File

@ -3,8 +3,8 @@ package frames
import ( import (
"bytes" "bytes"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A WindowUpdateFrame in QUIC // A WindowUpdateFrame in QUIC

View File

@ -15,59 +15,72 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// Client is a HTTP2 client doing QUIC requests type roundTripperOpts struct {
type Client struct { DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex mutex sync.RWMutex
dialAddr func(hostname string, config *quic.Config) (quic.Session, error) tlsConf *tls.Config
config *quic.Config config *quic.Config
opts *roundTripperOpts
t *QuicRoundTripper
hostname string hostname string
encryptionLevel protocol.EncryptionLevel encryptionLevel protocol.EncryptionLevel
handshakeErr error handshakeErr error
dialChan chan struct{} // will be closed once the handshake is complete and the header stream has been opened dialOnce sync.Once
session quic.Session session quic.Session
headerStream quic.Stream headerStream quic.Stream
headerErr *qerr.QuicError headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response responses map[protocol.StreamID]chan *http.Response
} }
var _ h2quicClient = &Client{} var _ http.RoundTripper = &client{}
// NewClient creates a new client var defaultQuicConfig = &quic.Config{
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) *Client { RequestConnectionIDTruncation: true,
return &Client{ KeepAlive: true,
t: t, }
dialAddr: quic.DialAddr,
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname), hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response), responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted, encryptionLevel: protocol.EncryptionUnencrypted,
config: &quic.Config{ tlsConf: tlsConfig,
TLSConfig: tlsConfig, config: config,
RequestConnectionIDTruncation: true, opts: opts,
}, headerErrored: make(chan struct{}),
dialChan: make(chan struct{}),
} }
} }
// Dial dials the connection // dial dials the connection
func (c *Client) Dial() (err error) { func (c *client) dial() error {
defer func() { var err error
c.handshakeErr = err c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
close(c.dialChan)
}()
c.session, err = c.dialAddr(c.hostname, c.config)
if err != nil { if err != nil {
return err return err
} }
@ -82,10 +95,10 @@ func (c *Client) Dial() (err error) {
} }
c.requestWriter = newRequestWriter(c.headerStream) c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream() go c.handleHeaderStream()
return return nil
} }
func (c *Client) handleHeaderStream() { func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream) h2framer := http2.NewFramer(nil, c.headerStream)
@ -111,7 +124,7 @@ func (c *Client) handleHeaderStream() {
} }
c.mutex.RLock() c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock() c.mutex.RUnlock()
if !ok { if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream)) c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
@ -122,41 +135,38 @@ func (c *Client) handleHeaderStream() {
if err != nil { if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error()) c.headerErr = qerr.Error(qerr.InternalError, err.Error())
} }
headerChan <- rsp responseChan <- rsp
} }
// stop all running request // stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error()) utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock() close(c.headerErrored)
for _, responseChan := range c.responses {
close(responseChan)
}
c.mutex.Unlock()
} }
// Do executes a request and returns a response // Roundtrip executes a request and returns a response
func (c *Client) Do(req *http.Request) (*http.Response, error) { func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one // TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" { if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme") return nil, errors.New("quic http2: unsupported scheme")
} }
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
utils.Debugf("%s vs %s", req.Host, c.hostname) return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
return nil, errors.New("h2quic Client BUG: Do called for the wrong client")
} }
hasBody := (req.Body != nil) c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
// wait until the handshake is complete
<-c.dialChan
if c.handshakeErr != nil { if c.handshakeErr != nil {
return nil, c.handshakeErr return nil, c.handshakeErr
} }
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response) responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync() dataStream, err := c.session.OpenStreamSync()
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
c.mutex.Lock() c.mutex.Lock()
@ -164,14 +174,14 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Unlock() c.mutex.Unlock()
var requestedGzip bool var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true requestedGzip = true
} }
// TODO: add support for trailers // TODO: add support for trailers
endStream := !hasBody endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil { if err != nil {
c.Close(err) _ = c.CloseWithError(err)
return nil, err return nil, err
} }
@ -198,15 +208,15 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mutex.Lock() c.mutex.Lock()
delete(c.responses, dataStream.StreamID()) delete(c.responses, dataStream.StreamID())
c.mutex.Unlock() c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc: case err := <-resc:
bodySent = true bodySent = true
if err != nil { if err != nil {
return nil, err return nil, err
} }
case <-c.headerErrored:
// an error occured on the header stream
_ = c.CloseWithError(c.headerErr)
return nil, c.headerErr
} }
} }
@ -230,11 +240,10 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
} }
res.Request = req res.Request = req
return res, nil return res, nil
} }
func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() { defer func() {
cerr := body.Close() cerr := body.Close()
if err == nil { if err == nil {
@ -252,8 +261,15 @@ func (c *Client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
} }
// Close closes the client // Close closes the client
func (c *Client) Close(e error) { func (c *client) CloseWithError(e error) error {
_ = c.session.Close(e) if c.session == nil {
return nil
}
return c.session.Close(e)
}
func (c *client) Close() error {
return c.CloseWithError(nil)
} }
// copied from net/transport.go // copied from net/transport.go

View File

@ -13,8 +13,8 @@ import (
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type requestWriter struct { type requestWriter struct {

View File

@ -8,8 +8,8 @@ import (
"sync" "sync"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )

View File

@ -4,20 +4,23 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/lex/httplex" "golang.org/x/net/lex/httplex"
) )
type h2quicClient interface { type roundTripCloser interface {
Dial() error http.RoundTripper
Do(*http.Request) (*http.Response, error) io.Closer
} }
// QuicRoundTripper implements the http.RoundTripper interface // RoundTripper implements the http.RoundTripper interface
type QuicRoundTripper struct { type RoundTripper struct {
mutex sync.Mutex mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from // DisableCompression, if true, prevents the Transport from
@ -34,13 +37,29 @@ type QuicRoundTripper struct {
// tls.Client. If nil, the default configuration is used. // tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
clients map[string]h2quicClient // QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
clients map[string]roundTripCloser
} }
var _ http.RoundTripper = &QuicRoundTripper{} // RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
// RoundTrip does a round trip var _ roundTripCloser = &RoundTripper{}
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil { if req.URL == nil {
closeRequestBody(req) closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL") return nil, errors.New("quic: nil Request.URL")
@ -76,35 +95,48 @@ func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
} }
hostname := authorityAddr("https", hostnameFromRequest(req)) hostname := authorityAddr("https", hostnameFromRequest(req))
client, err := r.getClient(hostname) cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return client.Do(req) return cl.RoundTrip(req)
} }
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) { // RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.clients == nil { if r.clients == nil {
r.clients = make(map[string]h2quicClient) r.clients = make(map[string]roundTripCloser)
} }
client, ok := r.clients[hostname] client, ok := r.clients[hostname]
if !ok { if !ok {
client = NewClient(r, r.TLSClientConfig, hostname) if onlyCached {
err := client.Dial() return nil, ErrNoCachedConn
if err != nil {
return nil, err
} }
client = newClient(hostname, r.TLSClientConfig, &roundTripperOpts{DisableCompression: r.DisableCompression}, r.QuicConfig)
r.clients[hostname] = client r.clients[hostname] = client
} }
return client, nil return client, nil
} }
func (r *QuicRoundTripper) disableCompression() bool { // Close closes the QUIC connections that this RoundTripper has used
return r.DisableCompression func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
} }
func closeRequestBody(req *http.Request) { func closeRequestBody(req *http.Request) {

View File

@ -13,9 +13,9 @@ import (
"time" "time"
quic "github.com/lucas-clemente/quic-go" quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
@ -29,10 +29,20 @@ type remoteCloser interface {
CloseRemote(protocol.ByteCount) CloseRemote(protocol.ByteCount)
} }
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections. // Server is a HTTP2 server listening for QUIC connections.
type Server struct { type Server struct {
*http.Server *http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use // Private flag for demo, do not use
CloseAfterFirstRequest bool CloseAfterFirstRequest bool
@ -69,11 +79,11 @@ func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
} }
// Serve an existing UDP connection. // Serve an existing UDP connection.
func (s *Server) Serve(conn *net.UDPConn) error { func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn) return s.serveImpl(s.TLSConfig, conn)
} }
func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil { if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server") return errors.New("use of h2quic.Server without http.Server")
} }
@ -83,17 +93,12 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
return errors.New("ListenAndServe may only be called once") return errors.New("ListenAndServe may only be called once")
} }
config := quic.Config{
TLSConfig: tlsConfig,
Versions: protocol.SupportedVersions,
}
var ln quic.Listener var ln quic.Listener
var err error var err error
if conn == nil { if conn == nil {
ln, err = quic.ListenAddr(s.Addr, &config) ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else { } else {
ln, err = quic.Listen(conn, &config) ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
} }
if err != nil { if err != nil {
s.listenerMutex.Unlock() s.listenerMutex.Unlock()
@ -255,7 +260,6 @@ func (s *Server) CloseGracefully(timeout time.Duration) error {
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC. // SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443): // The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alternate-Protocol: 443:quic
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30" // Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error { func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port) port := atomic.LoadUint32(&s.port)
@ -283,7 +287,6 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
} }
} }
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString)) hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil return nil

View File

@ -5,9 +5,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// ConnectionParametersManager negotiates and stores the connection parameters // ConnectionParametersManager negotiates and stores the connection parameters
@ -50,6 +50,8 @@ type connectionParametersManager struct {
sendConnectionFlowControlWindow protocol.ByteCount sendConnectionFlowControlWindow protocol.ByteCount
receiveStreamFlowControlWindow protocol.ByteCount receiveStreamFlowControlWindow protocol.ByteCount
receiveConnectionFlowControlWindow protocol.ByteCount receiveConnectionFlowControlWindow protocol.ByteCount
maxReceiveStreamFlowControlWindow protocol.ByteCount
maxReceiveConnectionFlowControlWindow protocol.ByteCount
} }
var _ ConnectionParametersManager = &connectionParametersManager{} var _ ConnectionParametersManager = &connectionParametersManager{}
@ -61,14 +63,19 @@ var (
) )
// NewConnectionParamatersManager creates a new connection parameters manager // NewConnectionParamatersManager creates a new connection parameters manager
func NewConnectionParamatersManager(pers protocol.Perspective, v protocol.VersionNumber) ConnectionParametersManager { func NewConnectionParamatersManager(
pers protocol.Perspective, v protocol.VersionNumber,
maxReceiveStreamFlowControlWindow protocol.ByteCount, maxReceiveConnectionFlowControlWindow protocol.ByteCount,
) ConnectionParametersManager {
h := &connectionParametersManager{ h := &connectionParametersManager{
perspective: pers, perspective: pers,
version: v, version: v,
sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client
sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client
receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, receiveStreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, receiveConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
maxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
maxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
} }
if h.perspective == protocol.PerspectiveServer { if h.perspective == protocol.PerspectiveServer {
@ -207,10 +214,7 @@ func (h *connectionParametersManager) GetReceiveStreamFlowControlWindow() protoc
// GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveStreamFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveStreamFlowControlWindow
return protocol.MaxReceiveStreamFlowControlWindowServer
}
return protocol.MaxReceiveStreamFlowControlWindowClient
} }
// GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data // GetReceiveConnectionFlowControlWindow gets the size of the stream-level flow control window for receiving data
@ -222,10 +226,7 @@ func (h *connectionParametersManager) GetReceiveConnectionFlowControlWindow() pr
// GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data // GetMaxReceiveConnectionFlowControlWindow gets the maximum size of the stream-level flow control window for sending data
func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount { func (h *connectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
if h.perspective == protocol.PerspectiveServer { return h.maxReceiveConnectionFlowControlWindow
return protocol.MaxReceiveConnectionFlowControlWindowServer
}
return protocol.MaxReceiveConnectionFlowControlWindowClient
} }
// GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection // GetMaxOutgoingStreams gets the maximum number of outgoing streams per connection

View File

@ -12,9 +12,9 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type cryptoSetupClient struct { type cryptoSetupClient struct {
@ -332,7 +332,6 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil { if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil { } else if h.secureAEAD != nil {
@ -342,6 +341,10 @@ func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
} }
} }
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.sealUnencrypted
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()

View File

@ -10,9 +10,9 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// KeyDerivationFunction is used for key derivation // KeyDerivationFunction is used for key derivation
@ -214,12 +214,16 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) { func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
if h.forwardSecureAEAD != nil && h.sentSHLO {
return protocol.EncryptionForwardSecure, h.sealForwardSecure return protocol.EncryptionForwardSecure, h.sealForwardSecure
} else if h.secureAEAD != nil { }
// secureAEAD and forwardSecureAEAD are created at the same time (when receiving the CHLO) return protocol.EncryptionUnencrypted, h.sealUnencrypted
// make sure that the SHLO isn't sent forward-secure }
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.sealSecure return protocol.EncryptionSecure, h.sealSecure
} }
return protocol.EncryptionUnencrypted, h.sealUnencrypted return protocol.EncryptionUnencrypted, h.sealUnencrypted
@ -251,7 +255,6 @@ func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protoc
} }
func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
h.sentSHLO = true
return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) return h.secureAEAD.Seal(dst, src, packetNumber, associatedData)
} }

View File

@ -5,8 +5,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (

View File

@ -7,9 +7,9 @@ import (
"io" "io"
"sort" "sort"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// A HandshakeMessage is a handshake message // A HandshakeMessage is a handshake message

View File

@ -15,6 +15,7 @@ type CryptoSetup interface {
GetSealer() (protocol.EncryptionLevel, Sealer) GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
} }
// TransportParameters are parameters sent to the peer during the handshake // TransportParameters are parameters sent to the peer during the handshake

View File

@ -8,8 +8,8 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type serverConfigClient struct { type serverConfigClient struct {

View File

@ -0,0 +1,133 @@
package handshaketests
import (
"crypto/tls"
"fmt"
"net"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/integrationtests/proxy"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/testdata"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Handshake integration tets", func() {
var (
proxy *quicproxy.QuicProxy
server quic.Listener
serverConfig *quic.Config
testStartedAt time.Time
)
rtt := 400 * time.Millisecond
BeforeEach(func() {
serverConfig = &quic.Config{}
})
AfterEach(func() {
Expect(proxy.Close()).To(Succeed())
Expect(server.Close()).To(Succeed())
})
runServerAndProxy := func() {
var err error
// start the server
server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
// start the proxy
proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{
RemoteAddr: server.Addr().String(),
DelayPacket: func(_ quicproxy.Direction, _ protocol.PacketNumber) time.Duration { return rtt / 2 },
})
Expect(err).ToNot(HaveOccurred())
testStartedAt = time.Now()
go func() {
for {
_, _ = server.Accept()
}
}()
}
expectDurationInRTTs := func(num int) {
testDuration := time.Since(testStartedAt)
expectedDuration := time.Duration(num) * rtt
Expect(testDuration).To(SatisfyAll(
BeNumerically(">=", expectedDuration),
BeNumerically("<", expectedDuration+rtt),
))
}
It("fails when there's no matching version, after 1 RTT", func() {
Expect(len(protocol.SupportedVersions)).To(BeNumerically(">", 1))
serverConfig.Versions = protocol.SupportedVersions[:1]
runServerAndProxy()
clientConfig := &quic.Config{
Versions: protocol.SupportedVersions[1:2],
}
_, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig)
Expect(err).To(HaveOccurred())
Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion))
expectDurationInRTTs(1)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// 1 RTT to become forward-secure
It("is forward-secure after 3 RTTs", func() {
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(3)
})
// 1 RTT for verifying the source address
// 1 RTT to become secure
// TODO (marten-seemann): enable this test (see #625)
PIt("is secure after 2 RTTs", func() {
utils.SetLogLevel(utils.LogLevelDebug)
runServerAndProxy()
_, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
fmt.Println("#### is non fw secure ###")
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("is forward-secure after 2 RTTs when the server doesn't require an STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return true
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
expectDurationInRTTs(2)
})
It("doesn't complete the handshake when the server never accepts the STK", func() {
serverConfig.AcceptSTK = func(_ net.Addr, _ *quic.STK) bool {
return false
}
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects))
})
It("doesn't complete the handshake when the handshake timeout is too short", func() {
serverConfig.HandshakeTimeout = 2 * rtt
runServerAndProxy()
_, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).To(HaveOccurred())
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
// 2 RTTs during the timeout
// plus 1 RTT: the timer starts 0.5 RTTs after sending the first packet, and the CONNECTION_CLOSE needs another 0.5 RTTs to reach the client
expectDurationInRTTs(3)
})
})

View File

@ -1,7 +1,6 @@
package quic package quic
import ( import (
"crypto/tls"
"io" "io"
"net" "net"
"time" "time"
@ -11,12 +10,32 @@ import (
// Stream is the interface implemented by QUIC streams // Stream is the interface implemented by QUIC streams
type Stream interface { type Stream interface {
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
io.Reader io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
io.Writer io.Writer
io.Closer io.Closer
StreamID() protocol.StreamID StreamID() protocol.StreamID
// Reset closes the stream with an error. // Reset closes the stream with an error.
Reset(error) Reset(error)
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
} }
// A Session is a QUIC connection between two peers. // A Session is a QUIC connection between two peers.
@ -37,6 +56,9 @@ type Session interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent. // Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error Close(error) error
// WaitUntilClosed() blocks until the session is closed.
// Warning: This API should not be considered stable and might change soon.
WaitUntilClosed()
} }
// A NonFWSession is a QUIC connection between two peers half-way through the handshake. // A NonFWSession is a QUIC connection between two peers half-way through the handshake.
@ -61,21 +83,31 @@ type STK struct {
// Config contains all configuration data needed for a QUIC server or client. // Config contains all configuration data needed for a QUIC server or client.
// More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. // More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441.
type Config struct { type Config struct {
TLSConfig *tls.Config
// The QUIC versions that can be negotiated. // The QUIC versions that can be negotiated.
// If not set, it uses all versions available. // If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon. // Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber Versions []protocol.VersionNumber
// Ask the server to truncate the connection ID sent in the Public Header. // Ask the server to truncate the connection ID sent in the Public Header.
// If not set, the default checks if
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated. // This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client. // Currently only valid for the client.
RequestConnectionIDTruncation bool RequestConnectionIDTruncation bool
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// AcceptSTK determines if an STK is accepted. // AcceptSTK determines if an STK is accepted.
// It is called with stk = nil if the client didn't send an STK. // It is called with stk = nil if the client didn't send an STK.
// If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours // If not set, it verifies that the address matches, and that the STK was issued within the last 24 hours.
// This option is only valid for the server. // This option is only valid for the server.
AcceptSTK func(clientAddr net.Addr, stk *STK) bool AcceptSTK func(clientAddr net.Addr, stk *STK) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow protocol.ByteCount
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow protocol.ByteCount
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
} }
// A Listener for incoming QUIC connections // A Listener for incoming QUIC connections

View File

@ -0,0 +1,153 @@
// Automatically generated by MockGen. DO NOT EDIT!
// Source: github.com/lucas-clemente/quic-go/handshake (interfaces: ConnectionParametersManager)
package mocks
import (
gomock "github.com/golang/mock/gomock"
handshake "github.com/lucas-clemente/quic-go/handshake"
protocol "github.com/lucas-clemente/quic-go/protocol"
time "time"
)
// Mock of ConnectionParametersManager interface
type MockConnectionParametersManager struct {
ctrl *gomock.Controller
recorder *_MockConnectionParametersManagerRecorder
}
// Recorder for MockConnectionParametersManager (not exported)
type _MockConnectionParametersManagerRecorder struct {
mock *MockConnectionParametersManager
}
func NewMockConnectionParametersManager(ctrl *gomock.Controller) *MockConnectionParametersManager {
mock := &MockConnectionParametersManager{ctrl: ctrl}
mock.recorder = &_MockConnectionParametersManagerRecorder{mock}
return mock
}
func (_m *MockConnectionParametersManager) EXPECT() *_MockConnectionParametersManagerRecorder {
return _m.recorder
}
func (_m *MockConnectionParametersManager) GetHelloMap() (map[handshake.Tag][]byte, error) {
ret := _m.ctrl.Call(_m, "GetHelloMap")
ret0, _ := ret[0].(map[handshake.Tag][]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockConnectionParametersManagerRecorder) GetHelloMap() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetHelloMap")
}
func (_m *MockConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration {
ret := _m.ctrl.Call(_m, "GetIdleConnectionStateLifetime")
ret0, _ := ret[0].(time.Duration)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetIdleConnectionStateLifetime() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetIdleConnectionStateLifetime")
}
func (_m *MockConnectionParametersManager) GetMaxIncomingStreams() uint32 {
ret := _m.ctrl.Call(_m, "GetMaxIncomingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxIncomingStreams() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxIncomingStreams")
}
func (_m *MockConnectionParametersManager) GetMaxOutgoingStreams() uint32 {
ret := _m.ctrl.Call(_m, "GetMaxOutgoingStreams")
ret0, _ := ret[0].(uint32)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxOutgoingStreams() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxOutgoingStreams")
}
func (_m *MockConnectionParametersManager) GetMaxReceiveConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetMaxReceiveConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetMaxReceiveStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetMaxReceiveStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetMaxReceiveStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetMaxReceiveStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetReceiveConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetReceiveConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetReceiveStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetReceiveStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetReceiveStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetSendConnectionFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetSendConnectionFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetSendConnectionFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendConnectionFlowControlWindow")
}
func (_m *MockConnectionParametersManager) GetSendStreamFlowControlWindow() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "GetSendStreamFlowControlWindow")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) GetSendStreamFlowControlWindow() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetSendStreamFlowControlWindow")
}
func (_m *MockConnectionParametersManager) SetFromMap(_param0 map[handshake.Tag][]byte) error {
ret := _m.ctrl.Call(_m, "SetFromMap", _param0)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) SetFromMap(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SetFromMap", arg0)
}
func (_m *MockConnectionParametersManager) TruncateConnectionID() bool {
ret := _m.ctrl.Call(_m, "TruncateConnectionID")
ret0, _ := ret[0].(bool)
return ret0
}
func (_mr *_MockConnectionParametersManagerRecorder) TruncateConnectionID() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "TruncateConnectionID")
}

View File

@ -0,0 +1,4 @@
package mocks
//go:generate mockgen -destination mocks_fc/flow_control_manager.go -package mocks_fc github.com/lucas-clemente/quic-go/flowcontrol FlowControlManager
//go:generate mockgen -destination cpm.go -package mocks github.com/lucas-clemente/quic-go/handshake ConnectionParametersManager

View File

@ -0,0 +1,140 @@
// Automatically generated by MockGen. DO NOT EDIT!
// Source: github.com/lucas-clemente/quic-go/flowcontrol (interfaces: FlowControlManager)
package mocks_fc
import (
gomock "github.com/golang/mock/gomock"
flowcontrol "github.com/lucas-clemente/quic-go/flowcontrol"
protocol "github.com/lucas-clemente/quic-go/protocol"
)
// Mock of FlowControlManager interface
type MockFlowControlManager struct {
ctrl *gomock.Controller
recorder *_MockFlowControlManagerRecorder
}
// Recorder for MockFlowControlManager (not exported)
type _MockFlowControlManagerRecorder struct {
mock *MockFlowControlManager
}
func NewMockFlowControlManager(ctrl *gomock.Controller) *MockFlowControlManager {
mock := &MockFlowControlManager{ctrl: ctrl}
mock.recorder = &_MockFlowControlManagerRecorder{mock}
return mock
}
func (_m *MockFlowControlManager) EXPECT() *_MockFlowControlManagerRecorder {
return _m.recorder
}
func (_m *MockFlowControlManager) AddBytesRead(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesRead", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) AddBytesRead(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesRead", arg0, arg1)
}
func (_m *MockFlowControlManager) AddBytesSent(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "AddBytesSent", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) AddBytesSent(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "AddBytesSent", arg0, arg1)
}
func (_m *MockFlowControlManager) GetReceiveWindow(_param0 protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "GetReceiveWindow", _param0)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) GetReceiveWindow(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetReceiveWindow", arg0)
}
func (_m *MockFlowControlManager) GetWindowUpdates() []flowcontrol.WindowUpdate {
ret := _m.ctrl.Call(_m, "GetWindowUpdates")
ret0, _ := ret[0].([]flowcontrol.WindowUpdate)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) GetWindowUpdates() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "GetWindowUpdates")
}
func (_m *MockFlowControlManager) NewStream(_param0 protocol.StreamID, _param1 bool) {
_m.ctrl.Call(_m, "NewStream", _param0, _param1)
}
func (_mr *_MockFlowControlManagerRecorder) NewStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "NewStream", arg0, arg1)
}
func (_m *MockFlowControlManager) RemainingConnectionWindowSize() protocol.ByteCount {
ret := _m.ctrl.Call(_m, "RemainingConnectionWindowSize")
ret0, _ := ret[0].(protocol.ByteCount)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) RemainingConnectionWindowSize() *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "RemainingConnectionWindowSize")
}
func (_m *MockFlowControlManager) RemoveStream(_param0 protocol.StreamID) {
_m.ctrl.Call(_m, "RemoveStream", _param0)
}
func (_mr *_MockFlowControlManagerRecorder) RemoveStream(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "RemoveStream", arg0)
}
func (_m *MockFlowControlManager) ResetStream(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "ResetStream", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) ResetStream(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "ResetStream", arg0, arg1)
}
func (_m *MockFlowControlManager) SendWindowSize(_param0 protocol.StreamID) (protocol.ByteCount, error) {
ret := _m.ctrl.Call(_m, "SendWindowSize", _param0)
ret0, _ := ret[0].(protocol.ByteCount)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) SendWindowSize(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "SendWindowSize", arg0)
}
func (_m *MockFlowControlManager) UpdateHighestReceived(_param0 protocol.StreamID, _param1 protocol.ByteCount) error {
ret := _m.ctrl.Call(_m, "UpdateHighestReceived", _param0, _param1)
ret0, _ := ret[0].(error)
return ret0
}
func (_mr *_MockFlowControlManagerRecorder) UpdateHighestReceived(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateHighestReceived", arg0, arg1)
}
func (_m *MockFlowControlManager) UpdateWindow(_param0 protocol.StreamID, _param1 protocol.ByteCount) (bool, error) {
ret := _m.ctrl.Call(_m, "UpdateWindow", _param0, _param1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (_mr *_MockFlowControlManagerRecorder) UpdateWindow(arg0, arg1 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCall(_mr.mock, "UpdateWindow", arg0, arg1)
}

View File

@ -1,26 +1,26 @@
package utils package utils
import ( import (
"fmt"
"log" "log"
"os" "os"
"strconv"
"time" "time"
) )
// LogLevel of quic-go // LogLevel of quic-go
type LogLevel uint8 type LogLevel uint8
const ( const logEnv = "QUIC_GO_LOG_LEVEL"
logEnv = "QUIC_GO_LOG_LEVEL"
// LogLevelDebug enables debug logs (e.g. packet contents) const (
LogLevelDebug LogLevel = iota // LogLevelNothing disables
// LogLevelInfo enables info logs (e.g. packets) LogLevelNothing LogLevel = iota
LogLevelInfo
// LogLevelError enables err logs // LogLevelError enables err logs
LogLevelError LogLevelError
// LogLevelNothing disables // LogLevelInfo enables info logs (e.g. packets)
LogLevelNothing LogLevelInfo
// LogLevelDebug enables debug logs (e.g. packet contents)
LogLevelDebug
) )
var ( var (
@ -49,14 +49,14 @@ func Debugf(format string, args ...interface{}) {
// Infof logs something // Infof logs something
func Infof(format string, args ...interface{}) { func Infof(format string, args ...interface{}) {
if logLevel <= LogLevelInfo { if logLevel >= LogLevelInfo {
logMessage(format, args...) logMessage(format, args...)
} }
} }
// Errorf logs something // Errorf logs something
func Errorf(format string, args ...interface{}) { func Errorf(format string, args ...interface{}) {
if logLevel <= LogLevelError { if logLevel >= LogLevelError {
logMessage(format, args...) logMessage(format, args...)
} }
} }
@ -79,13 +79,16 @@ func init() {
} }
func readLoggingEnv() { func readLoggingEnv() {
env := os.Getenv(logEnv) switch os.Getenv(logEnv) {
if env == "" { case "":
return return
case "DEBUG":
logLevel = LogLevelDebug
case "INFO":
logLevel = LogLevelInfo
case "ERROR":
logLevel = LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
} }
level, err := strconv.Atoi(env)
if err != nil {
return
}
logLevel = LogLevel(level)
} }

View File

@ -0,0 +1,43 @@
package utils
import "time"
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
t *time.Timer
read bool
deadline time.Time
}
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(0)}
}
// Chan returns the channel of the wrapped timer
func (t *Timer) Chan() <-chan time.Time {
return t.t.C
}
// Reset the timer, no matter whether the value was read or not
func (t *Timer) Reset(deadline time.Time) {
if deadline.Equal(t.deadline) {
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !t.t.Stop() && !t.read {
<-t.t.C
}
t.t.Reset(deadline.Sub(time.Now()))
t.read = false
t.deadline = deadline
}
// SetRead should be called after the value from the chan was read
func (t *Timer) SetRead() {
t.read = true
}

View File

@ -9,7 +9,6 @@ import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
) )
type packedPacket struct { type packedPacket struct {
@ -24,18 +23,24 @@ type packetPacker struct {
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
cryptoSetup handshake.CryptoSetup cryptoSetup handshake.CryptoSetup
// as long as packets are not sent with forward-secure encryption, we limit the MaxPacketSize such that they can be retransmitted as a whole
isForwardSecure bool
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
connectionParameters handshake.ConnectionParametersManager
streamFramer *streamFramer
connectionParameters handshake.ConnectionParametersManager
streamFramer *streamFramer
controlFrames []frames.Frame controlFrames []frames.Frame
stopWaiting *frames.StopWaitingFrame
ackFrame *frames.AckFrame
leastUnacked protocol.PacketNumber
} }
func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber) *packetPacker { func newPacketPacker(connectionID protocol.ConnectionID,
cryptoSetup handshake.CryptoSetup,
connectionParameters handshake.ConnectionParametersManager,
streamFramer *streamFramer,
perspective protocol.Perspective,
version protocol.VersionNumber,
) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
connectionID: connectionID, connectionID: connectionID,
@ -48,109 +53,91 @@ func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.C
} }
// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame) (*packedPacket, error) {
// in case the connection is closed, all queued control frames aren't of any use anymore frames := []frames.Frame{ccf}
// discard them and queue the ConnectionCloseFrame encLevel, sealer := p.cryptoSetup.GetSealer()
p.controlFrames = []frames.Frame{ccf} ph := p.getPublicHeader(encLevel)
return p.packPacket(nil, leastUnacked, nil) raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
} }
// RetransmitNonForwardSecurePacket retransmits a handshake packet, that was sent with less than forward-secure encryption func (p *packetPacker) PackAckPacket() (*packedPacket, error) {
func (p *packetPacker) RetransmitNonForwardSecurePacket(stopWaitingFrame *frames.StopWaitingFrame, packet *ackhandler.Packet) (*packedPacket, error) { if p.ackFrame == nil {
return nil, errors.New("packet packer BUG: no ack frame queued")
}
encLevel, sealer := p.cryptoSetup.GetSealer()
ph := p.getPublicHeader(encLevel)
frames := []frames.Frame{p.ackFrame}
if p.stopWaiting != nil {
p.stopWaiting.PacketNumber = ph.PacketNumber
p.stopWaiting.PacketNumberLen = ph.PacketNumberLen
frames = append(frames, p.stopWaiting)
p.stopWaiting = nil
}
p.ackFrame = nil
raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, err
}
// PackHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption
func (p *packetPacker) PackHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure { if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment") return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")
} }
if stopWaitingFrame == nil { sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
}
return p.packPacket(stopWaitingFrame, 0, packet)
}
// PackPacket packs a new packet
// the stopWaitingFrame is *guaranteed* to be included in the next packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
p.controlFrames = append(p.controlFrames, controlFrames...)
return p.packPacket(stopWaitingFrame, leastUnacked, nil)
}
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber, handshakePacketToRetransmit *ackhandler.Packet) (*packedPacket, error) {
// handshakePacketToRetransmit is only set for handshake retransmissions
isHandshakeRetransmission := (handshakePacketToRetransmit != nil)
var sealFunc handshake.Sealer
var encLevel protocol.EncryptionLevel
if isHandshakeRetransmission {
var err error
encLevel = handshakePacketToRetransmit.EncryptionLevel
sealFunc, err = p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
if err != nil {
return nil, err
}
} else {
encLevel, sealFunc = p.cryptoSetup.GetSealer()
}
currentPacketNumber := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked)
responsePublicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: currentPacketNumber,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
}
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
responsePublicHeader.VersionFlag = true
responsePublicHeader.VersionNumber = p.version
}
publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
}
ph := p.getPublicHeader(packet.EncryptionLevel)
p.stopWaiting.PacketNumber = ph.PacketNumber
p.stopWaiting.PacketNumberLen = ph.PacketNumberLen
frames := append([]frames.Frame{p.stopWaiting}, packet.Frames...)
p.stopWaiting = nil
raw, err := p.writeAndSealPacket(ph, frames, sealer)
return &packedPacket{
number: ph.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: packet.EncryptionLevel,
}, err
}
if stopWaitingFrame != nil { // PackPacket packs a new packet
stopWaitingFrame.PacketNumber = currentPacketNumber // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
stopWaitingFrame.PacketNumberLen = packetNumberLen func (p *packetPacker) PackPacket() (*packedPacket, error) {
if p.streamFramer.HasCryptoStreamFrame() {
return p.packCryptoPacket()
} }
// we're packing a ConnectionClose, don't add any StreamFrames encLevel, sealer := p.cryptoSetup.GetSealer()
var isConnectionClose bool
if len(p.controlFrames) == 1 { publicHeader := p.getPublicHeader(encLevel)
_, isConnectionClose = p.controlFrames[0].(*frames.ConnectionCloseFrame) publicHeaderLength, err := publicHeader.GetLength(p.perspective)
if err != nil {
return nil, err
}
if p.stopWaiting != nil {
p.stopWaiting.PacketNumber = publicHeader.PacketNumber
p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen
} }
var payloadFrames []frames.Frame maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
if isHandshakeRetransmission { payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
payloadFrames = append(payloadFrames, stopWaitingFrame) if err != nil {
// don't retransmit Acks and StopWaitings return nil, err
for _, f := range handshakePacketToRetransmit.Frames {
switch f.(type) {
case *frames.AckFrame:
continue
case *frames.StopWaitingFrame:
continue
}
payloadFrames = append(payloadFrames, f)
}
} else if isConnectionClose {
payloadFrames = []frames.Frame{p.controlFrames[0]}
} else {
maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
if !p.isForwardSecure {
maxSize -= protocol.NonForwardSecurePacketSizeReduction
}
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, maxSize)
if err != nil {
return nil, err
}
} }
// Check if we have enough frames to send // Check if we have enough frames to send
@ -158,71 +145,76 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
return nil, nil return nil, nil
} }
// Don't send out packets that only contain a StopWaitingFrame // Don't send out packets that only contain a StopWaitingFrame
if len(payloadFrames) == 1 && stopWaitingFrame != nil { if len(payloadFrames) == 1 && p.stopWaiting != nil {
return nil, nil return nil, nil
} }
p.stopWaiting = nil
p.ackFrame = nil
raw := getPacketBuffer() raw, err := p.writeAndSealPacket(publicHeader, payloadFrames, sealer)
buffer := bytes.NewBuffer(raw) if err != nil {
if err = responsePublicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err return nil, err
} }
payloadStartIndex := buffer.Len()
var hasNonCryptoStreamData bool // does this frame contain any stream frame on a stream > 1
for _, frame := range payloadFrames {
if sf, ok := frame.(*frames.StreamFrame); ok && sf.StreamID != 1 {
hasNonCryptoStreamData = true
}
err = frame.Write(buffer, p.version)
if err != nil {
return nil, err
}
}
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
raw = raw[0:buffer.Len()]
_ = sealFunc(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], currentPacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12]
if hasNonCryptoStreamData && encLevel <= protocol.EncryptionUnencrypted {
return nil, qerr.AttemptToSendUnencryptedStreamData
}
num := p.packetNumberGenerator.Pop()
if num != currentPacketNumber {
return nil, errors.New("PacketPacker BUG: Peeked and Popped packet numbers do not match.")
}
return &packedPacket{ return &packedPacket{
number: currentPacketNumber, number: publicHeader.PacketNumber,
raw: raw, raw: raw,
frames: payloadFrames, frames: payloadFrames,
encryptionLevel: encLevel, encryptionLevel: encLevel,
}, nil }, nil
} }
func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFrame, maxFrameSize protocol.ByteCount) ([]frames.Frame, error) { func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
publicHeader := p.getPublicHeader(encLevel)
publicHeaderLength, err := publicHeader.GetLength(p.perspective)
if err != nil {
return nil, err
}
maxLen := protocol.MaxFrameAndPublicHeaderSize - protocol.NonForwardSecurePacketSizeReduction - publicHeaderLength
frames := []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
raw, err := p.writeAndSealPacket(publicHeader, frames, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
number: publicHeader.PacketNumber,
raw: raw,
frames: frames,
encryptionLevel: encLevel,
}, nil
}
func (p *packetPacker) composeNextPacket(
maxFrameSize protocol.ByteCount,
canSendStreamFrames bool,
) ([]frames.Frame, error) {
var payloadLength protocol.ByteCount var payloadLength protocol.ByteCount
var payloadFrames []frames.Frame var payloadFrames []frames.Frame
if stopWaitingFrame != nil { // STOP_WAITING and ACK will always fit
payloadFrames = append(payloadFrames, stopWaitingFrame) if p.stopWaiting != nil {
minLength, err := stopWaitingFrame.MinLength(p.version) payloadFrames = append(payloadFrames, p.stopWaiting)
l, err := p.stopWaiting.MinLength(p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
payloadLength += minLength payloadLength += l
}
if p.ackFrame != nil {
payloadFrames = append(payloadFrames, p.ackFrame)
l, err := p.ackFrame.MinLength(p.version)
if err != nil {
return nil, err
}
payloadLength += l
} }
for len(p.controlFrames) > 0 { for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1] frame := p.controlFrames[len(p.controlFrames)-1]
minLength, _ := frame.MinLength(p.version) // controlFrames does not contain any StopWaitingFrames. So it will *never* return an error minLength, err := frame.MinLength(p.version)
if err != nil {
return nil, err
}
if payloadLength+minLength > maxFrameSize { if payloadLength+minLength > maxFrameSize {
break break
} }
@ -235,6 +227,10 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra
return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize)
} }
if !canSendStreamFrames {
return payloadFrames, nil
}
// temporarily increase the maxFrameSize by 2 bytes // temporarily increase the maxFrameSize by 2 bytes
// this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set // this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set
// however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size // however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size
@ -257,10 +253,79 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra
return payloadFrames, nil return payloadFrames, nil
} }
func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { func (p *packetPacker) QueueControlFrame(frame frames.Frame) {
p.controlFrames = append(p.controlFrames, f) switch f := frame.(type) {
case *frames.StopWaitingFrame:
p.stopWaiting = f
case *frames.AckFrame:
p.ackFrame = f
default:
p.controlFrames = append(p.controlFrames, f)
}
} }
func (p *packetPacker) SetForwardSecure() { func (p *packetPacker) getPublicHeader(encLevel protocol.EncryptionLevel) *PublicHeader {
p.isForwardSecure = true pnum := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(pnum, p.leastUnacked)
publicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: pnum,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
}
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure {
publicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure {
publicHeader.VersionFlag = true
publicHeader.VersionNumber = p.version
}
return publicHeader
}
func (p *packetPacker) writeAndSealPacket(
publicHeader *PublicHeader,
payloadFrames []frames.Frame,
sealer handshake.Sealer,
) ([]byte, error) {
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err := publicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err
}
payloadStartIndex := buffer.Len()
for _, frame := range payloadFrames {
err := frame.Write(buffer, p.version)
if err != nil {
return nil, err
}
}
if protocol.ByteCount(buffer.Len()+12) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
raw = raw[0:buffer.Len()]
_ = sealer(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], publicHeader.PacketNumber, raw[:payloadStartIndex])
raw = raw[0 : buffer.Len()+12]
num := p.packetNumberGenerator.Pop()
if num != publicHeader.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return raw, nil
}
func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
if p.perspective == protocol.PerspectiveClient {
return encLevel >= protocol.EncryptionSecure
}
return encLevel == protocol.EncryptionForwardSecure
}
func (p *packetPacker) SetLeastUnacked(leastUnacked protocol.PacketNumber) {
p.leastUnacked = leastUnacked
} }

View File

@ -10,6 +10,11 @@ import (
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type unpackedPacket struct {
encryptionLevel protocol.EncryptionLevel
frames []frames.Frame
}
type quicAEAD interface { type quicAEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
} }

View File

@ -31,7 +31,7 @@ type StreamID uint32
type ByteCount uint64 type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount // MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = math.MaxUint64 const MaxByteCount = ByteCount(math.MaxUint64)
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on // MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header, // ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,

View File

@ -39,21 +39,21 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB
// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data // DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB const DefaultMaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB
// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data // DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB const DefaultMaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB
// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client // DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using // This is the value that Chromium is using
const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB const DefaultMaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB
// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server // DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
// This is the value that Google servers are using // This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB const DefaultMaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window // ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using // This is the value that Chromium is using
@ -128,8 +128,8 @@ const MaxIdleTimeoutServer = 1 * time.Minute
// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server // MaxIdleTimeoutClient is the idle timeout that the client suggests to the server
const MaxIdleTimeoutClient = 2 * time.Minute const MaxIdleTimeoutClient = 2 * time.Minute
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. // DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const MaxTimeForCryptoHandshake = 10 * time.Second const DefaultHandshakeTimeout = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed // ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted // after this time all information about the old connection will be deleted

View File

@ -4,9 +4,9 @@ import (
"bytes" "bytes"
"errors" "errors"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
var ( var (

View File

@ -6,8 +6,8 @@ import (
"errors" "errors"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type publicReset struct { type publicReset struct {

View File

@ -3,7 +3,7 @@ package qerr
import ( import (
"fmt" "fmt"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/internal/utils"
) )
// ErrorCode can be used as a normal error without reason. // ErrorCode can be used as a normal error without reason.
@ -31,6 +31,16 @@ func (e *QuicError) Error() string {
return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage) return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage)
} }
func (e *QuicError) Timeout() bool {
switch e.ErrorCode {
case NetworkIdleTimeout,
HandshakeTimeout,
TimeoutsWithOpenStreams:
return true
}
return false
}
// ToQuicError converts an arbitrary error to a QuicError. It leaves QuicErrors // ToQuicError converts an arbitrary error to a QuicError. It leaves QuicErrors
// unchanged, and properly handles `ErrorCode`s. // unchanged, and properly handles `ErrorCode`s.
func ToQuicError(err error) *QuicError { func ToQuicError(err error) *QuicError {

View File

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"net" "net"
"sync" "sync"
@ -9,9 +10,9 @@ import (
"github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
// packetHandler handles packets // packetHandler handles packets
@ -19,11 +20,13 @@ type packetHandler interface {
Session Session
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
run() error run() error
closeRemote(error)
} }
// A Listener of QUIC // A Listener of QUIC
type server struct { type server struct {
config *Config tlsConf *tls.Config
config *Config
conn net.PacketConn conn net.PacketConn
@ -38,14 +41,15 @@ type server struct {
sessionQueue chan Session sessionQueue chan Session
errorChan chan struct{} errorChan chan struct{}
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, config *Config) (packetHandler, <-chan handshakeEvent, error) newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error)
} }
var _ Listener = &server{} var _ Listener = &server{}
// ListenAddr creates a QUIC server listening on a given address. // ListenAddr creates a QUIC server listening on a given address.
// The listener is not active until Serve() is called. // The listener is not active until Serve() is called.
func ListenAddr(addr string, config *Config) (Listener, error) { // The tls.Config must not be nil, the quic.Config may be nil.
func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -54,13 +58,14 @@ func ListenAddr(addr string, config *Config) (Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Listen(conn, config) return Listen(conn, tlsConf, config)
} }
// Listen listens for QUIC connections on a given net.PacketConn. // Listen listens for QUIC connections on a given net.PacketConn.
// The listener is not active until Serve() is called. // The listener is not active until Serve() is called.
func Listen(conn net.PacketConn, config *Config) (Listener, error) { // The tls.Config must not be nil, the quic.Config may be nil.
certChain := crypto.NewCertChain(config.TLSConfig) func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
certChain := crypto.NewCertChain(tlsConf)
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
return nil, err return nil, err
@ -72,6 +77,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
s := &server{ s := &server{
conn: conn, conn: conn,
tlsConf: tlsConf,
config: populateServerConfig(config), config: populateServerConfig(config),
certChain: certChain, certChain: certChain,
scfg: scfg, scfg: scfg,
@ -101,20 +107,42 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool {
return sourceAddr == stk.remoteAddr return sourceAddr == stk.remoteAddr
} }
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config { func populateServerConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions versions := config.Versions
if len(versions) == 0 { if len(versions) == 0 {
versions = protocol.SupportedVersions versions = protocol.SupportedVersions
} }
vsa := defaultAcceptSTK vsa := defaultAcceptSTK
if config.AcceptSTK != nil { if config.AcceptSTK != nil {
vsa = config.AcceptSTK vsa = config.AcceptSTK
} }
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
}
return &Config{ return &Config{
TLSConfig: config.TLSConfig, Versions: versions,
Versions: versions, HandshakeTimeout: handshakeTimeout,
AcceptSTK: vsa, AcceptSTK: vsa,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
} }
} }
@ -238,6 +266,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
version, version,
hdr.ConnectionID, hdr.ConnectionID,
s.scfg, s.scfg,
s.tlsConf,
s.config, s.config,
) )
if err != nil { if err != nil {

View File

@ -1,10 +1,11 @@
package quic package quic
import ( import (
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync/atomic" "sync"
"time" "time"
"github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/ackhandler"
@ -12,9 +13,9 @@ import (
"github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
) )
type unpacker interface { type unpacker interface {
@ -31,7 +32,6 @@ type receivedPacket struct {
var ( var (
errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream")
errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream")
errSessionAlreadyClosed = errors.New("cannot close session; it was already closed before")
) )
var ( var (
@ -54,6 +54,7 @@ type session struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
tlsConf *tls.Config
config *Config config *Config
conn connection conn connection
@ -77,8 +78,10 @@ type session struct {
sendingScheduled chan struct{} sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate. // closeChan is used to notify the run loop that it should terminate.
closeChan chan closeError closeChan chan closeError
// runClosed is closed once the run loop exits
// it is used to block Close() and WaitUntilClosed()
runClosed chan struct{} runClosed chan struct{}
closed uint32 // atomic bool closeOnce sync.Once
// when we receive too many undecryptable packets during the handshake, we send a Public reset // when we receive too many undecryptable packets during the handshake, we send a Public reset
// but only after a time of protocol.PublicResetTimeout has passed // but only after a time of protocol.PublicResetTimeout has passed
@ -97,8 +100,6 @@ type session struct {
// it receives at most 3 handshake events: 2 when the encryption level changes, and one error // it receives at most 3 handshake events: 2 when the encryption level changes, and one error
handshakeChan chan<- handshakeEvent handshakeChan chan<- handshakeEvent
nextAckScheduledTime time.Time
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
lastRcvdPacketNumber protocol.PacketNumber lastRcvdPacketNumber protocol.PacketNumber
@ -109,9 +110,10 @@ type session struct {
sessionCreationTime time.Time sessionCreationTime time.Time
lastNetworkActivityTime time.Time lastNetworkActivityTime time.Time
timer *time.Timer timer *utils.Timer
currentDeadline time.Time // keepAlivePingSent stores whether a Ping frame was sent to the peer or not
timerRead bool // it is reset as soon as we receive a packet from the peer
keepAlivePingSent bool
} }
var _ Session = &session{} var _ Session = &session{}
@ -122,6 +124,7 @@ func newSession(
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
sCfg *handshake.ServerConfig, sCfg *handshake.ServerConfig,
tlsConf *tls.Config,
config *Config, config *Config,
) (packetHandler, <-chan handshakeEvent, error) { ) (packetHandler, <-chan handshakeEvent, error) {
s := &session{ s := &session{
@ -130,46 +133,8 @@ func newSession(
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config, config: config,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
} }
return s.setup(sCfg, "", nil)
s.setup()
cryptoStream, _ := s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
if hstk == nil {
return config.AcceptSTK(clientAddr, nil)
}
return config.AcceptSTK(
clientAddr,
&STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime},
)
}
var err error
s.cryptoSetup, err = newCryptoSetup(
connectionID,
conn.RemoteAddr(),
v,
sCfg,
cryptoStream,
s.connectionParameters,
config.Versions,
verifySourceAddr,
aeadChanged,
)
if err != nil {
return nil, nil, err
}
s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, handshakeChan, err
} }
// declare this as a variable, such that we can it mock it in the tests // declare this as a variable, such that we can it mock it in the tests
@ -178,6 +143,7 @@ var newClientSession = func(
hostname string, hostname string,
v protocol.VersionNumber, v protocol.VersionNumber,
connectionID protocol.ConnectionID, connectionID protocol.ConnectionID,
tlsConf *tls.Config,
config *Config, config *Config,
negotiatedVersions []protocol.VersionNumber, negotiatedVersions []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) { ) (packetHandler, <-chan handshakeEvent, error) {
@ -186,68 +152,92 @@ var newClientSession = func(
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
tlsConf: tlsConf,
config: config, config: config,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
} }
return s.setup(nil, hostname, negotiatedVersions)
}
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) func (s *session) setup(
s.setup() scfg *handshake.ServerConfig,
hostname string,
negotiatedVersions []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3) handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan s.handshakeChan = handshakeChan
cryptoStream, _ := s.OpenStream() s.runClosed = make(chan struct{})
var err error s.handshakeCompleteChan = make(chan error, 1)
s.cryptoSetup, err = newCryptoSetupClient(
hostname,
connectionID,
v,
cryptoStream,
config.TLSConfig,
s.connectionParameters,
aeadChanged,
&handshake.TransportParameters{RequestConnectionIDTruncation: config.RequestConnectionIDTruncation},
negotiatedVersions,
)
if err != nil {
return nil, nil, err
}
s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, handshakeChan, err
}
// setup is called from newSession and newClientSession and initializes values that are independent of the perspective
func (s *session) setup() {
s.rttStats = &congestion.RTTStats{}
flowControlManager := flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats)
sentPacketHandler := ackhandler.NewSentPacketHandler(s.rttStats)
now := time.Now()
s.sentPacketHandler = sentPacketHandler
s.flowControlManager = flowControlManager
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets) s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1) s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1) s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets) s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.aeadChanged = make(chan protocol.EncryptionLevel, 2)
s.runClosed = make(chan struct{})
s.handshakeCompleteChan = make(chan error, 1)
s.timer = time.NewTimer(0) s.timer = utils.NewTimer()
now := time.Now()
s.lastNetworkActivityTime = now s.lastNetworkActivityTime = now
s.sessionCreationTime = now s.sessionCreationTime = now
s.rttStats = &congestion.RTTStats{}
s.connectionParameters = handshake.NewConnectionParamatersManager(s.perspective, s.version,
s.config.MaxReceiveStreamFlowControlWindow, s.config.MaxReceiveConnectionFlowControlWindow)
s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats)
s.flowControlManager = flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats)
s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler()
s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters) s.streamsMap = newStreamsMap(s.newStream, s.perspective, s.connectionParameters)
s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager) s.streamFramer = newStreamFramer(s.streamsMap, s.flowControlManager)
var err error
if s.perspective == protocol.PerspectiveServer {
cryptoStream, _ := s.GetOrOpenStream(1)
_, _ = s.AcceptStream() // don't expose the crypto stream
verifySourceAddr := func(clientAddr net.Addr, hstk *handshake.STK) bool {
var stk *STK
if hstk != nil {
stk = &STK{remoteAddr: hstk.RemoteAddr, sentTime: hstk.SentTime}
}
return s.config.AcceptSTK(clientAddr, stk)
}
s.cryptoSetup, err = newCryptoSetup(
s.connectionID,
s.conn.RemoteAddr(),
s.version,
scfg,
cryptoStream,
s.connectionParameters,
s.config.Versions,
verifySourceAddr,
aeadChanged,
)
} else {
cryptoStream, _ := s.OpenStream()
s.cryptoSetup, err = newCryptoSetupClient(
hostname,
s.connectionID,
s.version,
cryptoStream,
s.tlsConf,
s.connectionParameters,
aeadChanged,
&handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation},
negotiatedVersions,
)
}
if err != nil {
return nil, nil, err
}
s.packer = newPacketPacker(s.connectionID,
s.cryptoSetup,
s.connectionParameters,
s.streamFramer,
s.perspective,
s.version,
)
s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version}
return s, handshakeChan, nil
} }
// run the session main loop // run the session main loop
@ -276,8 +266,8 @@ runLoop:
select { select {
case closeErr = <-s.closeChan: case closeErr = <-s.closeChan:
break runLoop break runLoop
case <-s.timer.C: case <-s.timer.Chan():
s.timerRead = true s.timer.SetRead()
// We do all the interesting stuff after the switch statement, so // We do all the interesting stuff after the switch statement, so
// nothing to see here. // nothing to see here.
case <-s.sendingScheduled: case <-s.sendingScheduled:
@ -290,7 +280,7 @@ runLoop:
s.tryQueueingUndecryptablePacket(p) s.tryQueueingUndecryptablePacket(p)
continue continue
} }
s.close(err) s.closeLocal(err)
continue continue
} }
// This is a bit unclean, but works properly, since the packet always // This is a bit unclean, but works properly, since the packet always
@ -303,32 +293,35 @@ runLoop:
close(s.handshakeChan) close(s.handshakeChan)
close(s.handshakeCompleteChan) close(s.handshakeCompleteChan)
} else { } else {
if l == protocol.EncryptionForwardSecure {
s.packer.SetForwardSecure()
}
s.tryDecryptingQueuedPackets() s.tryDecryptingQueuedPackets()
s.handshakeChan <- handshakeEvent{encLevel: l} s.handshakeChan <- handshakeEvent{encLevel: l}
} }
} }
now := time.Now() now := time.Now()
if s.sentPacketHandler.GetAlarmTimeout().Before(now) { if timeout := s.sentPacketHandler.GetAlarmTimeout(); !timeout.IsZero() && timeout.Before(now) {
// This could cause packets to be retransmitted, so check it before trying // This could cause packets to be retransmitted, so check it before trying
// to send packets. // to send packets.
s.sentPacketHandler.OnAlarm() s.sentPacketHandler.OnAlarm()
} }
if s.config.KeepAlive && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.idleTimeout()/2 {
// send the PING frame since there is no activity in the session
s.packer.QueueControlFrame(&frames.PingFrame{})
s.keepAlivePingSent = true
}
if err := s.sendPacket(); err != nil { if err := s.sendPacket(); err != nil {
s.close(err) s.closeLocal(err)
} }
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() && s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout).Before(now) && len(s.undecryptablePackets) != 0 {
s.close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received")) s.closeLocal(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
} }
if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() { if now.Sub(s.lastNetworkActivityTime) >= s.idleTimeout() {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity.")) s.closeLocal(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
} }
if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= protocol.MaxTimeForCryptoHandshake { if !s.handshakeComplete && now.Sub(s.sessionCreationTime) >= s.config.HandshakeTimeout {
s.close(qerr.Error(qerr.NetworkIdleTimeout, "Crypto handshake did not complete in time.")) s.closeLocal(qerr.Error(qerr.HandshakeTimeout, "Crypto handshake did not complete in time."))
} }
s.garbageCollectStreams() s.garbageCollectStreams()
} }
@ -344,37 +337,33 @@ runLoop:
return closeErr.err return closeErr.err
} }
func (s *session) maybeResetTimer() { func (s *session) WaitUntilClosed() {
nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) <-s.runClosed
}
if !s.nextAckScheduledTime.IsZero() { func (s *session) maybeResetTimer() {
nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime) var deadline time.Time
if s.config.KeepAlive && s.handshakeComplete && !s.keepAlivePingSent {
deadline = s.lastNetworkActivityTime.Add(s.idleTimeout() / 2)
} else {
deadline = s.lastNetworkActivityTime.Add(s.idleTimeout())
}
if ackAlarm := s.receivedPacketHandler.GetAlarmTimeout(); !ackAlarm.IsZero() {
deadline = utils.MinTime(deadline, ackAlarm)
} }
if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() { if lossTime := s.sentPacketHandler.GetAlarmTimeout(); !lossTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, lossTime) deadline = utils.MinTime(deadline, lossTime)
} }
if !s.handshakeComplete { if !s.handshakeComplete {
handshakeDeadline := s.sessionCreationTime.Add(protocol.MaxTimeForCryptoHandshake) handshakeDeadline := s.sessionCreationTime.Add(s.config.HandshakeTimeout)
nextDeadline = utils.MinTime(nextDeadline, handshakeDeadline) deadline = utils.MinTime(deadline, handshakeDeadline)
} }
if !s.receivedTooManyUndecrytablePacketsTime.IsZero() { if !s.receivedTooManyUndecrytablePacketsTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout)) deadline = utils.MinTime(deadline, s.receivedTooManyUndecrytablePacketsTime.Add(protocol.PublicResetTimeout))
} }
if nextDeadline.Equal(s.currentDeadline) { s.timer.Reset(deadline)
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !s.timer.Stop() && !s.timerRead {
<-s.timer.C
}
s.timer.Reset(nextDeadline.Sub(time.Now()))
s.timerRead = false
s.currentDeadline = nextDeadline
} }
func (s *session) idleTimeout() time.Duration { func (s *session) idleTimeout() time.Duration {
@ -398,6 +387,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
} }
s.lastNetworkActivityTime = p.rcvTime s.lastNetworkActivityTime = p.rcvTime
s.keepAlivePingSent = false
hdr := p.publicHeader hdr := p.publicHeader
data := p.data data := p.data
@ -433,19 +423,8 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
// Only do this after decrypting, so we are sure the packet is not attacker-controlled // Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable()) isRetransmittable := ackhandler.HasRetransmittableFrames(packet.frames)
// ignore duplicate packets if err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, isRetransmittable); err != nil {
if err == ackhandler.ErrDuplicatePacket {
utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber)
return nil
}
// ignore packets with packet numbers smaller than the LeastUnacked of a StopWaiting
if err == ackhandler.ErrPacketSmallerThanLastStopWaiting {
utils.Infof("Ignoring packet 0x%x due to ErrPacketSmallerThanLastStopWaiting", hdr.PacketNumber)
return nil
}
if err != nil {
return err return err
} }
@ -462,7 +441,7 @@ func (s *session) handleFrames(fs []frames.Frame) error {
case *frames.AckFrame: case *frames.AckFrame:
err = s.handleAckFrame(frame) err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame: case *frames.ConnectionCloseFrame:
s.registerClose(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
case *frames.GoawayFrame: case *frames.GoawayFrame:
err = errors.New("unimplemented: handling GOAWAY frames") err = errors.New("unimplemented: handling GOAWAY frames")
case *frames.StopWaitingFrame: case *frames.StopWaitingFrame:
@ -548,48 +527,31 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error {
return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime)
} }
func (s *session) registerClose(e error, remoteClose bool) error { func (s *session) closeLocal(e error) {
// Only close once s.closeOnce.Do(func() {
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { s.closeChan <- closeError{err: e, remote: false}
return errSessionAlreadyClosed })
} }
if e == nil { func (s *session) closeRemote(e error) {
e = qerr.PeerGoingAway s.closeOnce.Do(func() {
} s.closeChan <- closeError{err: e, remote: true}
})
if e == errCloseSessionForNewVersion {
s.streamsMap.CloseWithError(e)
s.closeStreamsWithError(e)
}
s.closeChan <- closeError{err: e, remote: remoteClose}
return nil
} }
// Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
// It waits until the run loop has stopped before returning // It waits until the run loop has stopped before returning
func (s *session) Close(e error) error { func (s *session) Close(e error) error {
err := s.registerClose(e, false) s.closeLocal(e)
if err == errSessionAlreadyClosed {
return nil
}
// wait for the run loop to finish
<-s.runClosed <-s.runClosed
return err return nil
}
// close the connection. Use this when called from the run loop
func (s *session) close(e error) error {
err := s.registerClose(e, false)
if err == errSessionAlreadyClosed {
return nil
}
return err
} }
func (s *session) handleCloseError(closeErr closeError) error { func (s *session) handleCloseError(closeErr closeError) error {
if closeErr.err == nil {
closeErr.err = qerr.PeerGoingAway
}
var quicErr *qerr.QuicError var quicErr *qerr.QuicError
var ok bool var ok bool
if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok { if quicErr, ok = closeErr.err.(*qerr.QuicError); !ok {
@ -602,13 +564,12 @@ func (s *session) handleCloseError(closeErr closeError) error {
utils.Errorf("Closing session with error: %s", closeErr.err.Error()) utils.Errorf("Closing session with error: %s", closeErr.err.Error())
} }
s.streamsMap.CloseWithError(quicErr)
if closeErr.err == errCloseSessionForNewVersion { if closeErr.err == errCloseSessionForNewVersion {
return nil return nil
} }
s.streamsMap.CloseWithError(quicErr)
s.closeStreamsWithError(quicErr)
// If this is a remote close we're done here // If this is a remote close we're done here
if closeErr.remote { if closeErr.remote {
return nil return nil
@ -620,27 +581,37 @@ func (s *session) handleCloseError(closeErr closeError) error {
return s.sendConnectionClose(quicErr) return s.sendConnectionClose(quicErr)
} }
func (s *session) closeStreamsWithError(err error) {
s.streamsMap.Iterate(func(str *stream) (bool, error) {
str.Cancel(err)
return true, nil
})
}
func (s *session) sendPacket() error { func (s *session) sendPacket() error {
s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
// Get WindowUpdate frames
// this call triggers the flow controller to increase the flow control windows, if necessary
windowUpdateFrames := s.getWindowUpdateFrames()
for _, wuf := range windowUpdateFrames {
s.packer.QueueControlFrame(wuf)
}
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
s.packer.QueueControlFrame(ack)
}
// Repeatedly try sending until we don't have any more data, or run out of the congestion window // Repeatedly try sending until we don't have any more data, or run out of the congestion window
for { for {
if !s.sentPacketHandler.SendingAllowed() { if !s.sentPacketHandler.SendingAllowed() {
return nil if ack == nil {
} return nil
}
var controlFrames []frames.Frame // If we aren't allowed to send, at least try sending an ACK frame
swf := s.sentPacketHandler.GetStopWaitingFrame(false)
// get WindowUpdate frames if swf != nil {
// this call triggers the flow controller to increase the flow control windows, if necessary s.packer.QueueControlFrame(swf)
windowUpdateFrames := s.getWindowUpdateFrames() }
for _, wuf := range windowUpdateFrames { packet, err := s.packer.PackAckPacket()
controlFrames = append(controlFrames, wuf) if err != nil {
return err
}
return s.sendPackedPacket(packet)
} }
// check for retransmissions first // check for retransmissions first
@ -649,75 +620,67 @@ func (s *session) sendPacket() error {
if retransmitPacket == nil { if retransmitPacket == nil {
break break
} }
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure { if retransmitPacket.EncryptionLevel != protocol.EncryptionForwardSecure {
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) if s.handshakeComplete {
stopWaitingFrame := s.sentPacketHandler.GetStopWaitingFrame(true) // Don't retransmit handshake packets when the handshake is complete
var packet *packedPacket
packet, err := s.packer.RetransmitNonForwardSecurePacket(stopWaitingFrame, retransmitPacket)
if err != nil {
return err
}
if packet == nil {
continue continue
} }
err = s.sendPackedPacket(packet) utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true))
packet, err := s.packer.PackHandshakeRetransmission(retransmitPacket)
if err != nil { if err != nil {
return err return err
} }
continue if err = s.sendPackedPacket(packet); err != nil {
return err
}
} else { } else {
utils.Debugf("\tDequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber)
// resend the frames that were in the packet // resend the frames that were in the packet
for _, frame := range retransmitPacket.GetFramesForRetransmission() { for _, frame := range retransmitPacket.GetFramesForRetransmission() {
switch frame.(type) { switch f := frame.(type) {
case *frames.StreamFrame: case *frames.StreamFrame:
s.streamFramer.AddFrameForRetransmission(frame.(*frames.StreamFrame)) s.streamFramer.AddFrameForRetransmission(f)
case *frames.WindowUpdateFrame: case *frames.WindowUpdateFrame:
// only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream // only retransmit WindowUpdates if the stream is not yet closed and the we haven't sent another WindowUpdate with a higher ByteOffset for the stream
var currentOffset protocol.ByteCount
f := frame.(*frames.WindowUpdateFrame)
currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID) currentOffset, err := s.flowControlManager.GetReceiveWindow(f.StreamID)
if err == nil && f.ByteOffset >= currentOffset { if err == nil && f.ByteOffset >= currentOffset {
controlFrames = append(controlFrames, frame) s.packer.QueueControlFrame(f)
} }
default: default:
controlFrames = append(controlFrames, frame) s.packer.QueueControlFrame(frame)
} }
} }
} }
} }
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
controlFrames = append(controlFrames, ack)
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission() hasRetransmission := s.streamFramer.HasFramesForRetransmission()
var stopWaitingFrame *frames.StopWaitingFrame
if ack != nil || hasRetransmission { if ack != nil || hasRetransmission {
stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
if swf != nil {
s.packer.QueueControlFrame(swf)
}
} }
packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked()) packet, err := s.packer.PackPacket()
if err != nil { if err != nil || packet == nil {
return err return err
} }
if packet == nil { if err = s.sendPackedPacket(packet); err != nil {
return nil return err
}
// send every window update twice
for _, f := range windowUpdateFrames {
s.packer.QueueControlFrameForNextPacket(f)
} }
err = s.sendPackedPacket(packet) // send every window update twice
if err != nil { for _, f := range windowUpdateFrames {
return err s.packer.QueueControlFrame(f)
} }
s.nextAckScheduledTime = time.Time{} windowUpdateFrames = nil
ack = nil
} }
} }
func (s *session) sendPackedPacket(packet *packedPacket) error { func (s *session) sendPackedPacket(packet *packedPacket) error {
defer putPacketBuffer(packet.raw)
err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{ err := s.sentPacketHandler.SentPacket(&ackhandler.Packet{
PacketNumber: packet.number, PacketNumber: packet.number,
Frames: packet.frames, Frames: packet.frames,
@ -727,22 +690,19 @@ func (s *session) sendPackedPacket(packet *packedPacket) error {
if err != nil { if err != nil {
return err return err
} }
s.logPacket(packet) s.logPacket(packet)
return s.conn.Write(packet.raw)
err = s.conn.Write(packet.raw)
putPacketBuffer(packet.raw)
return err
} }
func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error {
packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage}, s.sentPacketHandler.GetLeastUnacked()) s.packer.SetLeastUnacked(s.sentPacketHandler.GetLeastUnacked())
packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{
ErrorCode: quicErr.ErrorCode,
ReasonPhrase: quicErr.ErrorMessage,
})
if err != nil { if err != nil {
return err return err
} }
if packet == nil {
return errors.New("Session BUG: expected packet not to be nil")
}
s.logPacket(packet) s.logPacket(packet)
return s.conn.Write(packet.raw) return s.conn.Write(packet.raw)
} }
@ -752,11 +712,9 @@ func (s *session) logPacket(packet *packedPacket) {
// We don't need to allocate the slices for calling the format functions // We don't need to allocate the slices for calling the format functions
return return
} }
if utils.Debug() { utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", packet.number, len(packet.raw), s.connectionID, packet.encryptionLevel)
utils.Debugf("-> Sending packet 0x%x (%d bytes), %s", packet.number, len(packet.raw), packet.encryptionLevel) for _, frame := range packet.frames {
for _, frame := range packet.frames { frames.LogFrame(frame, true)
frames.LogFrame(frame, true)
}
} }
} }
@ -790,27 +748,21 @@ func (s *session) WaitUntilHandshakeComplete() error {
} }
func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) {
s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ s.packer.QueueControlFrame(&frames.RstStreamFrame{
StreamID: id, StreamID: id,
ByteOffset: offset, ByteOffset: offset,
}) })
s.scheduleSending() s.scheduleSending()
} }
func (s *session) newStream(id protocol.StreamID) (*stream, error) { func (s *session) newStream(id protocol.StreamID) *stream {
stream, err := newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager)
if err != nil {
return nil, err
}
// TODO: find a better solution for determining which streams contribute to connection level flow control // TODO: find a better solution for determining which streams contribute to connection level flow control
if id == 1 || id == 3 { if id == 1 || id == 3 {
s.flowControlManager.NewStream(id, false) s.flowControlManager.NewStream(id, false)
} else { } else {
s.flowControlManager.NewStream(id, true) s.flowControlManager.NewStream(id, true)
} }
return newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager)
return stream, nil
} }
// garbageCollectStreams goes through all streams and removes EOF'ed streams // garbageCollectStreams goes through all streams and removes EOF'ed streams
@ -844,6 +796,7 @@ func (s *session) scheduleSending() {
func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
if s.handshakeComplete { if s.handshakeComplete {
utils.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.publicHeader, len(p.data))
return return
} }
if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets { if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
@ -875,11 +828,6 @@ func (s *session) getWindowUpdateFrames() []*frames.WindowUpdateFrame {
return res return res
} }
func (s *session) ackAlarmChanged(t time.Time) {
s.nextAckScheduledTime = t
s.maybeResetTimer()
}
func (s *session) LocalAddr() net.Addr { func (s *session) LocalAddr() net.Addr {
return s.conn.LocalAddr() return s.conn.LocalAddr()
} }

View File

@ -3,12 +3,14 @@ package quic
import ( import (
"fmt" "fmt"
"io" "io"
"net"
"sync" "sync"
"time"
"github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
// A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface // A Stream assembles the data from StreamFrames and provides a super-convenient Read-Interface
@ -40,31 +42,41 @@ type stream struct {
// resetRemotely is set if RegisterRemoteError() is called // resetRemotely is set if RegisterRemoteError() is called
resetRemotely utils.AtomicBool resetRemotely utils.AtomicBool
frameQueue *streamFrameSorter frameQueue *streamFrameSorter
newFrameOrErrCond sync.Cond readChan chan struct{}
readDeadline time.Time
dataForWriting []byte dataForWriting []byte
finSent utils.AtomicBool finSent utils.AtomicBool
rstSent utils.AtomicBool rstSent utils.AtomicBool
doneWritingOrErrCond sync.Cond writeChan chan struct{}
writeDeadline time.Time
flowControlManager flowcontrol.FlowControlManager flowControlManager flowcontrol.FlowControlManager
} }
type deadlineError struct{}
func (deadlineError) Error() string { return "deadline exceeded" }
func (deadlineError) Temporary() bool { return true }
func (deadlineError) Timeout() bool { return true }
var errDeadline net.Error = &deadlineError{}
// newStream creates a new Stream // newStream creates a new Stream
func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) (*stream, error) { func newStream(StreamID protocol.StreamID,
s := &stream{ onData func(),
onReset func(protocol.StreamID, protocol.ByteCount),
flowControlManager flowcontrol.FlowControlManager) *stream {
return &stream{
onData: onData, onData: onData,
onReset: onReset, onReset: onReset,
streamID: StreamID, streamID: StreamID,
flowControlManager: flowControlManager, flowControlManager: flowControlManager,
frameQueue: newStreamFrameSorter(), frameQueue: newStreamFrameSorter(),
readChan: make(chan struct{}, 1),
writeChan: make(chan struct{}, 1),
} }
s.newFrameOrErrCond.L = &s.mutex
s.doneWritingOrErrCond.L = &s.mutex
return s, nil
} }
// Read implements io.Reader. It is not thread safe! // Read implements io.Reader. It is not thread safe!
@ -83,10 +95,10 @@ func (s *stream) Read(p []byte) (int, error) {
for bytesRead < len(p) { for bytesRead < len(p) {
s.mutex.Lock() s.mutex.Lock()
frame := s.frameQueue.Head() frame := s.frameQueue.Head()
if frame == nil && bytesRead > 0 { if frame == nil && bytesRead > 0 {
err = s.err
s.mutex.Unlock() s.mutex.Unlock()
return bytesRead, s.err return bytesRead, err
} }
var err error var err error
@ -96,11 +108,28 @@ func (s *stream) Read(p []byte) (int, error) {
err = s.err err = s.err
break break
} }
deadline := s.readDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if frame != nil { if frame != nil {
s.readPosInFrame = int(s.readOffset - frame.Offset) s.readPosInFrame = int(s.readOffset - frame.Offset)
break break
} }
s.newFrameOrErrCond.Wait()
s.mutex.Unlock()
if deadline.IsZero() {
<-s.readChan
} else {
select {
case <-s.readChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
frame = s.frameQueue.Head() frame = s.frameQueue.Head()
} }
s.mutex.Unlock() s.mutex.Unlock()
@ -145,34 +174,49 @@ func (s *stream) Read(p []byte) (int, error) {
} }
func (s *stream) Write(p []byte) (int, error) { func (s *stream) Write(p []byte) (int, error) {
if s.resetLocally.Get() {
return 0, s.err
}
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.err != nil { if s.resetLocally.Get() || s.err != nil {
return 0, s.err return 0, s.err
} }
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
} }
s.dataForWriting = make([]byte, len(p)) s.dataForWriting = make([]byte, len(p))
copy(s.dataForWriting, p) copy(s.dataForWriting, p)
s.onData() s.onData()
for s.dataForWriting != nil && s.err == nil { var err error
s.doneWritingOrErrCond.Wait() for {
deadline := s.writeDeadline
if !deadline.IsZero() && !time.Now().Before(deadline) {
err = errDeadline
break
}
if s.dataForWriting == nil || s.err != nil {
break
}
s.mutex.Unlock()
if deadline.IsZero() {
<-s.writeChan
} else {
select {
case <-s.writeChan:
case <-time.After(deadline.Sub(time.Now())):
}
}
s.mutex.Lock()
} }
if err != nil {
return 0, err
}
if s.err != nil { if s.err != nil {
return 0, s.err return len(p) - len(s.dataForWriting), s.err
} }
return len(p), nil return len(p), nil
} }
@ -188,14 +232,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount {
func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
s.mutex.Lock() s.mutex.Lock()
if s.err != nil { defer s.mutex.Unlock()
s.mutex.Unlock()
return nil if s.err != nil || s.dataForWriting == nil {
}
if s.dataForWriting == nil {
s.mutex.Unlock()
return nil return nil
} }
var ret []byte var ret []byte
if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { if protocol.ByteCount(len(s.dataForWriting)) > maxBytes {
ret = s.dataForWriting[:maxBytes] ret = s.dataForWriting[:maxBytes]
@ -203,10 +245,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte {
} else { } else {
ret = s.dataForWriting ret = s.dataForWriting
s.dataForWriting = nil s.dataForWriting = nil
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
s.writeOffset += protocol.ByteCount(len(ret)) s.writeOffset += protocol.ByteCount(len(ret))
s.mutex.Unlock()
return ret return ret
} }
@ -249,7 +290,52 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error {
if err != nil && err != errDuplicateStreamData { if err != nil && err != errDuplicateStreamData {
return err return err
} }
s.newFrameOrErrCond.Signal() s.signalRead()
return nil
}
// signalRead performs a non-blocking send on the readChan
func (s *stream) signalRead() {
select {
case s.readChan <- struct{}{}:
default:
}
}
// signalRead performs a non-blocking send on the writeChan
func (s *stream) signalWrite() {
select {
case s.writeChan <- struct{}{}:
default:
}
}
func (s *stream) SetReadDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.readDeadline
s.readDeadline = t
s.mutex.Unlock()
// if the new deadline is before the currently set deadline, wake up Read()
if t.Before(oldDeadline) {
s.signalRead()
}
return nil
}
func (s *stream) SetWriteDeadline(t time.Time) error {
s.mutex.Lock()
oldDeadline := s.writeDeadline
s.writeDeadline = t
s.mutex.Unlock()
if t.Before(oldDeadline) {
s.signalWrite()
}
return nil
}
func (s *stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t) // SetReadDeadline never errors
_ = s.SetWriteDeadline(t) // SetWriteDeadline never errors
return nil return nil
} }
@ -266,8 +352,8 @@ func (s *stream) Cancel(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.newFrameOrErrCond.Signal() s.signalRead()
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
s.mutex.Unlock() s.mutex.Unlock()
} }
@ -282,8 +368,8 @@ func (s *stream) Reset(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.newFrameOrErrCond.Signal() s.signalRead()
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
if s.shouldSendReset() { if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset) s.onReset(s.streamID, s.writeOffset)
@ -302,7 +388,7 @@ func (s *stream) RegisterRemoteError(err error) {
// errors must not be changed! // errors must not be changed!
if s.err == nil { if s.err == nil {
s.err = err s.err = err
s.doneWritingOrErrCond.Signal() s.signalWrite()
} }
if s.shouldSendReset() { if s.shouldSendReset() {
s.onReset(s.streamID, s.writeOffset) s.onReset(s.streamID, s.writeOffset)

View File

@ -4,8 +4,8 @@ import (
"errors" "errors"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type streamFrameSorter struct { type streamFrameSorter struct {

View File

@ -3,8 +3,8 @@ package quic
import ( import (
"github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/flowcontrol"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type streamFramer struct { type streamFramer struct {
@ -45,6 +45,28 @@ func (f *streamFramer) HasFramesForRetransmission() bool {
return len(f.retransmissionQueue) > 0 return len(f.retransmissionQueue) > 0
} }
func (f *streamFramer) HasCryptoStreamFrame() bool {
// TODO(#657): Flow control
cs, _ := f.streamsMap.GetOrOpenStream(1)
return cs.lenOfDataForWriting() > 0
}
// TODO(lclemente): This is somewhat duplicate with the normal path for generating frames.
// TODO(#657): Flow control
func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *frames.StreamFrame {
if !f.HasCryptoStreamFrame() {
return nil
}
cs, _ := f.streamsMap.GetOrOpenStream(1)
frame := &frames.StreamFrame{
StreamID: 1,
Offset: cs.writeOffset,
}
frameHeaderBytes, _ := frame.MinLength(protocol.VersionWhatever) // can never error
frame.Data = cs.getDataForWriting(maxLen - frameHeaderBytes)
return frame
}
func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*frames.StreamFrame, currentLen protocol.ByteCount) { func (f *streamFramer) maybePopFramesForRetransmission(maxLen protocol.ByteCount) (res []*frames.StreamFrame, currentLen protocol.ByteCount) {
for len(f.retransmissionQueue) > 0 { for len(f.retransmissionQueue) > 0 {
frame := f.retransmissionQueue[0] frame := f.retransmissionQueue[0]
@ -76,7 +98,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
var currentLen protocol.ByteCount var currentLen protocol.ByteCount
fn := func(s *stream) (bool, error) { fn := func(s *stream) (bool, error) {
if s == nil { if s == nil || s.streamID == 1 /* crypto stream is handled separately */ {
return true, nil return true, nil
} }
@ -90,7 +112,8 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
maxLen := maxBytes - currentLen - frameHeaderBytes maxLen := maxBytes - currentLen - frameHeaderBytes
var sendWindowSize protocol.ByteCount var sendWindowSize protocol.ByteCount
if s.lenOfDataForWriting() != 0 { lenStreamData := s.lenOfDataForWriting()
if lenStreamData != 0 {
sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID) sendWindowSize, _ = f.flowControlManager.SendWindowSize(s.streamID)
maxLen = utils.MinByteCount(maxLen, sendWindowSize) maxLen = utils.MinByteCount(maxLen, sendWindowSize)
} }
@ -99,7 +122,12 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
return true, nil return true, nil
} }
data := s.getDataForWriting(maxLen) var data []byte
if lenStreamData != 0 {
// Only getDataForWriting() if we didn't have data earlier, so that we
// don't send without FC approval (if a Write() raced).
data = s.getDataForWriting(maxLen)
}
// This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests. // This is unlikely, but check it nonetheless, the scheduler might have jumped in. Seems to happen in ~20% of cases in the tests.
shouldSendFin := s.shouldSendFin() shouldSendFin := s.shouldSendFin()

View File

@ -36,7 +36,7 @@ type streamsMap struct {
} }
type streamLambda func(*stream) (bool, error) type streamLambda func(*stream) (bool, error)
type newStreamLambda func(protocol.StreamID) (*stream, error) type newStreamLambda func(protocol.StreamID) *stream
var ( var (
errMapAccess = errors.New("streamsMap: Error accessing the streams map") errMapAccess = errors.New("streamsMap: Error accessing the streams map")
@ -83,15 +83,27 @@ func (m *streamsMap) GetOrOpenStream(id protocol.StreamID) (*stream, error) {
return s, nil return s, nil
} }
if id <= m.highestStreamOpenedByPeer { if m.perspective == protocol.PerspectiveServer {
return nil, nil if id%2 == 0 {
if id <= m.nextStream { // this is a server-side stream that we already opened. Must have been closed already
return nil, nil
}
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a client-side stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
} }
if m.perspective == protocol.PerspectiveClient {
if m.perspective == protocol.PerspectiveServer && id%2 == 0 { if id%2 == 1 {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from client-side", id)) if id <= m.nextStream { // this is a client-side stream that we already opened.
} return nil, nil
if m.perspective == protocol.PerspectiveClient && id%2 == 1 { }
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d from server-side", id))
}
if id <= m.highestStreamOpenedByPeer { // this is a server-side stream that doesn't exist anymore. Must have been closed already
return nil, nil
}
} }
// sid is the next stream that will be opened // sid is the next stream that will be opened
@ -120,11 +132,6 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer)) return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
} }
s, err := m.newStream(id)
if err != nil {
return nil, err
}
if m.perspective == protocol.PerspectiveServer { if m.perspective == protocol.PerspectiveServer {
m.numIncomingStreams++ m.numIncomingStreams++
} else { } else {
@ -135,6 +142,7 @@ func (m *streamsMap) openRemoteStream(id protocol.StreamID) (*stream, error) {
m.highestStreamOpenedByPeer = id m.highestStreamOpenedByPeer = id
} }
s := m.newStream(id)
m.putStream(s) m.putStream(s)
return s, nil return s, nil
} }
@ -145,11 +153,6 @@ func (m *streamsMap) openStreamImpl() (*stream, error) {
return nil, qerr.TooManyOpenStreams return nil, qerr.TooManyOpenStreams
} }
s, err := m.newStream(id)
if err != nil {
return nil, err
}
if m.perspective == protocol.PerspectiveServer { if m.perspective == protocol.PerspectiveServer {
m.numOutgoingStreams++ m.numOutgoingStreams++
} else { } else {
@ -157,6 +160,7 @@ func (m *streamsMap) openStreamImpl() (*stream, error) {
} }
m.nextStream += 2 m.nextStream += 2
s := m.newStream(id)
m.putStream(s) m.putStream(s)
return s, nil return s, nil
} }
@ -319,8 +323,11 @@ func (m *streamsMap) RemoveStream(id protocol.StreamID) error {
func (m *streamsMap) CloseWithError(err error) { func (m *streamsMap) CloseWithError(err error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
m.closeErr = err m.closeErr = err
m.nextStreamOrErrCond.Broadcast() m.nextStreamOrErrCond.Broadcast()
m.openStreamOrErrCond.Broadcast() m.openStreamOrErrCond.Broadcast()
m.mutex.Unlock() for _, s := range m.openStreams {
m.streams[s].Cancel(err)
}
} }

View File

@ -1,31 +0,0 @@
package quic
import (
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
)
type unpackedPacket struct {
encryptionLevel protocol.EncryptionLevel
frames []frames.Frame
}
func (u *unpackedPacket) IsRetransmittable() bool {
for _, f := range u.frames {
switch f.(type) {
case *frames.StreamFrame:
return true
case *frames.RstStreamFrame:
return true
case *frames.WindowUpdateFrame:
return true
case *frames.BlockedFrame:
return true
case *frames.PingFrame:
return true
case *frames.GoawayFrame:
return true
}
}
return false
}

View File

@ -4,6 +4,7 @@ package dns
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
"io" "io"
@ -70,6 +71,43 @@ func Exchange(m *Msg, a string) (r *Msg, err error) {
return r, err return r, err
} }
// ExchangeContext performs a synchronous UDP query, like Exchange. It
// additionally obeys deadlines from the passed Context.
func ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg, err error) {
// Combine context deadline with built-in timeout. Context chooses whichever
// is sooner.
timeoutCtx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
deadline, _ := timeoutCtx.Deadline()
co := new(Conn)
dialer := net.Dialer{}
co.Conn, err = dialer.DialContext(timeoutCtx, "udp", a)
if err != nil {
return nil, err
}
defer co.Conn.Close()
opt := m.IsEdns0()
// If EDNS0 is used use that for size.
if opt != nil && opt.UDPSize() >= MinMsgSize {
co.UDPSize = opt.UDPSize()
}
co.SetWriteDeadline(deadline)
if err = co.WriteMsg(m); err != nil {
return nil, err
}
co.SetReadDeadline(deadline)
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
}
return r, err
}
// ExchangeConn performs a synchronous query. It sends the message m via the connection // ExchangeConn performs a synchronous query. It sends the message m via the connection
// c and waits for a reply. The connection c is not closed by ExchangeConn. // c and waits for a reply. The connection c is not closed by ExchangeConn.
// This function is going away, but can easily be mimicked: // This function is going away, but can easily be mimicked:
@ -106,8 +144,18 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
// buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit // buffer, see SetEdns0. Messages without an OPT RR will fallback to the historic limit
// of 512 bytes. // of 512 bytes.
func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
return c.ExchangeContext(context.Background(), m, a)
}
// ExchangeContext acts like Exchange, but honors the deadline on the provided
// context, if present. If there is both a context deadline and a configured
// timeout on the client, the earliest of the two takes effect.
func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (
r *Msg,
rtt time.Duration,
err error) {
if !c.SingleInflight { if !c.SingleInflight {
return c.exchange(m, a) return c.exchange(ctx, m, a)
} }
// This adds a bunch of garbage, TODO(miek). // This adds a bunch of garbage, TODO(miek).
t := "nop" t := "nop"
@ -119,7 +167,7 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
cl = cl1 cl = cl1
} }
r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) { r, rtt, err, shared := c.group.Do(m.Question[0].Name+t+cl, func() (*Msg, time.Duration, error) {
return c.exchange(m, a) return c.exchange(ctx, m, a)
}) })
if r != nil && shared { if r != nil && shared {
r = r.Copy() r = r.Copy()
@ -154,7 +202,7 @@ func (c *Client) writeTimeout() time.Duration {
return dnsTimeout return dnsTimeout
} }
func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err error) { func (c *Client) exchange(ctx context.Context, m *Msg, a string) (r *Msg, rtt time.Duration, err error) {
var co *Conn var co *Conn
network := "udp" network := "udp"
tls := false tls := false
@ -180,10 +228,13 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
deadline = time.Now().Add(c.Timeout) deadline = time.Now().Add(c.Timeout)
} }
dialDeadline := deadlineOrTimeoutOrCtx(ctx, deadline, c.dialTimeout())
dialTimeout := dialDeadline.Sub(time.Now())
if tls { if tls {
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout()) co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, dialTimeout)
} else { } else {
co, err = DialTimeout(network, a, c.dialTimeout()) co, err = DialTimeout(network, a, dialTimeout)
} }
if err != nil { if err != nil {
@ -202,12 +253,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
} }
co.TsigSecret = c.TsigSecret co.TsigSecret = c.TsigSecret
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout())) co.SetWriteDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.writeTimeout()))
if err = co.WriteMsg(m); err != nil { if err = co.WriteMsg(m); err != nil {
return nil, 0, err return nil, 0, err
} }
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout())) co.SetReadDeadline(deadlineOrTimeoutOrCtx(ctx, deadline, c.readTimeout()))
r, err = co.ReadMsg() r, err = co.ReadMsg()
if err == nil && r.Id != m.Id { if err == nil && r.Id != m.Id {
err = ErrId err = ErrId
@ -459,9 +510,22 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
return conn, nil return conn, nil
} }
// deadlineOrTimeout chooses between the provided deadline and timeout
// by always preferring the deadline so long as it's non-zero (regardless
// of which is bigger), and returns the equivalent deadline value.
func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time { func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
if deadline.IsZero() { if deadline.IsZero() {
return time.Now().Add(timeout) return time.Now().Add(timeout)
} }
return deadline return deadline
} }
// deadlineOrTimeoutOrCtx returns the earliest of: a context deadline, or the
// output of deadlineOrtimeout.
func deadlineOrTimeoutOrCtx(ctx context.Context, deadline time.Time, timeout time.Duration) time.Time {
result := deadlineOrTimeout(deadline, timeout)
if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(result) {
result = ctxDeadline
}
return result
}

View File

@ -96,7 +96,7 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte,
return hdr, len(msg), msg, err return hdr, len(msg), msg, err
} }
msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength) msg, err = truncateMsgFromRdlength(msg, off, hdr.Rdlength)
return hdr, off, msg, nil return hdr, off, msg, err
} }
// pack packs an RR header, returning the offset to the end of the header. // pack packs an RR header, returning the offset to the end of the header.

42
vendor/github.com/miekg/dns/types.go generated vendored
View File

@ -115,27 +115,27 @@ const (
ClassNONE = 254 ClassNONE = 254
ClassANY = 255 ClassANY = 255
// Message Response Codes. // Message Response Codes, see https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml
RcodeSuccess = 0 RcodeSuccess = 0 // NoError - No Error [DNS]
RcodeFormatError = 1 RcodeFormatError = 1 // FormErr - Format Error [DNS]
RcodeServerFailure = 2 RcodeServerFailure = 2 // ServFail - Server Failure [DNS]
RcodeNameError = 3 RcodeNameError = 3 // NXDomain - Non-Existent Domain [DNS]
RcodeNotImplemented = 4 RcodeNotImplemented = 4 // NotImp - Not Implemented [DNS]
RcodeRefused = 5 RcodeRefused = 5 // Refused - Query Refused [DNS]
RcodeYXDomain = 6 RcodeYXDomain = 6 // YXDomain - Name Exists when it should not [DNS Update]
RcodeYXRrset = 7 RcodeYXRrset = 7 // YXRRSet - RR Set Exists when it should not [DNS Update]
RcodeNXRrset = 8 RcodeNXRrset = 8 // NXRRSet - RR Set that should exist does not [DNS Update]
RcodeNotAuth = 9 RcodeNotAuth = 9 // NotAuth - Server Not Authoritative for zone [DNS Update]
RcodeNotZone = 10 RcodeNotZone = 10 // NotZone - Name not contained in zone [DNS Update/TSIG]
RcodeBadSig = 16 // TSIG RcodeBadSig = 16 // BADSIG - TSIG Signature Failure [TSIG]
RcodeBadVers = 16 // EDNS0 RcodeBadVers = 16 // BADVERS - Bad OPT Version [EDNS0]
RcodeBadKey = 17 RcodeBadKey = 17 // BADKEY - Key not recognized [TSIG]
RcodeBadTime = 18 RcodeBadTime = 18 // BADTIME - Signature out of time window [TSIG]
RcodeBadMode = 19 // TKEY RcodeBadMode = 19 // BADMODE - Bad TKEY Mode [TKEY]
RcodeBadName = 20 RcodeBadName = 20 // BADNAME - Duplicate key name [TKEY]
RcodeBadAlg = 21 RcodeBadAlg = 21 // BADALG - Algorithm not supported [TKEY]
RcodeBadTrunc = 22 // TSIG RcodeBadTrunc = 22 // BADTRUNC - Bad Truncation [TSIG]
RcodeBadCookie = 23 // DNS Cookies RcodeBadCookie = 23 // BADCOOKIE - Bad/missing Server Cookie [DNS Cookies]
// Message Opcodes. There is no 3. // Message Opcodes. There is no 3.
OpcodeQuery = 0 OpcodeQuery = 0

11
vendor/github.com/miekg/dns/xfr.go generated vendored
View File

@ -1,6 +1,7 @@
package dns package dns
import ( import (
"fmt"
"time" "time"
) )
@ -81,6 +82,10 @@ func (t *Transfer) inAxfr(id uint16, c chan *Envelope) {
return return
} }
if first { if first {
if in.Rcode != RcodeSuccess {
c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
return
}
if !isSOAFirst(in) { if !isSOAFirst(in) {
c <- &Envelope{in.Answer, ErrSoa} c <- &Envelope{in.Answer, ErrSoa}
return return
@ -126,6 +131,10 @@ func (t *Transfer) inIxfr(id uint16, c chan *Envelope) {
return return
} }
if first { if first {
if in.Rcode != RcodeSuccess {
c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
return
}
// A single SOA RR signals "no changes" // A single SOA RR signals "no changes"
if len(in.Answer) == 1 && isSOAFirst(in) { if len(in.Answer) == 1 && isSOAFirst(in) {
c <- &Envelope{in.Answer, nil} c <- &Envelope{in.Answer, nil}
@ -242,3 +251,5 @@ func isSOALast(in *Msg) bool {
} }
return false return false
} }
const errXFR = "bad xfr rcode: %d"

View File

@ -488,6 +488,7 @@ func link(p *parser, out *bytes.Buffer, data []byte, offset int) int {
} }
p.notes = append(p.notes, ref) p.notes = append(p.notes, ref)
p.notesRecord[string(ref.link)] = struct{}{}
link = ref.link link = ref.link
title = ref.title title = ref.title
@ -498,9 +499,10 @@ func link(p *parser, out *bytes.Buffer, data []byte, offset int) int {
return 0 return 0
} }
if t == linkDeferredFootnote { if t == linkDeferredFootnote && !p.isFootnote(lr) {
lr.noteId = len(p.notes) + 1 lr.noteId = len(p.notes) + 1
p.notes = append(p.notes, lr) p.notes = append(p.notes, lr)
p.notesRecord[string(lr.link)] = struct{}{}
} }
// keep link and title from reference // keep link and title from reference

View File

@ -218,7 +218,8 @@ type parser struct {
// Footnotes need to be ordered as well as available to quickly check for // Footnotes need to be ordered as well as available to quickly check for
// presence. If a ref is also a footnote, it's stored both in refs and here // presence. If a ref is also a footnote, it's stored both in refs and here
// in notes. Slice is nil if footnotes not enabled. // in notes. Slice is nil if footnotes not enabled.
notes []*reference notes []*reference
notesRecord map[string]struct{}
} }
func (p *parser) getRef(refid string) (ref *reference, found bool) { func (p *parser) getRef(refid string) (ref *reference, found bool) {
@ -241,6 +242,11 @@ func (p *parser) getRef(refid string) (ref *reference, found bool) {
return ref, found return ref, found
} }
func (p *parser) isFootnote(ref *reference) bool {
_, ok := p.notesRecord[string(ref.link)]
return ok
}
// //
// //
// Public interface // Public interface
@ -376,6 +382,7 @@ func MarkdownOptions(input []byte, renderer Renderer, opts Options) []byte {
if extensions&EXTENSION_FOOTNOTES != 0 { if extensions&EXTENSION_FOOTNOTES != 0 {
p.notes = make([]*reference, 0) p.notes = make([]*reference, 0)
p.notesRecord = make(map[string]struct{})
} }
first := firstPass(p, input) first := firstPass(p, input)

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"runtime" "runtime"
"strings" "strings"
@ -15,7 +16,17 @@ import (
var UserAgent string var UserAgent string
// HTTPClient is an HTTP client with a reasonable timeout value. // HTTPClient is an HTTP client with a reasonable timeout value.
var HTTPClient = http.Client{Timeout: 10 * time.Second} var HTTPClient = http.Client{
Transport: &http.Transport{
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 15 * time.Second,
ResponseHeaderTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}
const ( const (
// defaultGoUserAgent is the Go HTTP package user agent string. Too // defaultGoUserAgent is the Go HTTP package user agent string. Too

View File

@ -207,7 +207,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
return nil, "", responseError(res) return nil, "", responseError(res)
} }
curl := res.Header.Get("location") // cert permanent URL curl := res.Header.Get("Location") // cert permanent URL
if res.ContentLength == 0 { if res.ContentLength == 0 {
// no cert in the body; poll until we get it // no cert in the body; poll until we get it
cert, err := c.FetchCert(ctx, curl, bundle) cert, err := c.FetchCert(ctx, curl, bundle)
@ -240,7 +240,7 @@ func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]by
if res.StatusCode > 299 { if res.StatusCode > 299 {
return nil, responseError(res) return nil, responseError(res)
} }
d := retryAfter(res.Header.Get("retry-after"), 3*time.Second) d := retryAfter(res.Header.Get("Retry-After"), 3*time.Second)
select { select {
case <-time.After(d): case <-time.After(d):
// retry // retry
@ -444,7 +444,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
if err != nil { if err != nil {
return nil, err return nil, err
} }
retry := res.Header.Get("retry-after") retry := res.Header.Get("Retry-After")
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted { if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
res.Body.Close() res.Body.Close()
if err := sleep(retry, 1); err != nil { if err := sleep(retry, 1); err != nil {
@ -703,7 +703,7 @@ func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string
// clear any nonces that we might've stored that might now be // clear any nonces that we might've stored that might now be
// considered bad // considered bad
c.clearNonces() c.clearNonces()
retry := res.Header.Get("retry-after") retry := res.Header.Get("Retry-After")
if err := sleep(retry, 1); err != nil { if err := sleep(retry, 1); err != nil {
return nil, err return nil, err
} }

View File

@ -36,6 +36,9 @@ import (
// operating system-specific cache or temp directory. This may not // operating system-specific cache or temp directory. This may not
// be suitable for servers spanning multiple machines. // be suitable for servers spanning multiple machines.
// //
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted // The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS // connections. The returned *tls.Conn are returned before their TLS
// handshake has completed. // handshake has completed.
@ -58,6 +61,9 @@ func NewListener(domains ...string) net.Listener {
// Listener listens on the standard TLS port (443) on all interfaces // Listener listens on the standard TLS port (443) on all interfaces
// and returns a net.Listener returning *tls.Conn connections. // and returns a net.Listener returning *tls.Conn connections.
// //
// The returned listener uses a *tls.Config that enables HTTP/2, and
// should only be used with servers that support HTTP/2.
//
// The returned Listener also enables TCP keep-alives on the accepted // The returned Listener also enables TCP keep-alives on the accepted
// connections. The returned *tls.Conn are returned before their TLS // connections. The returned *tls.Conn are returned before their TLS
// handshake has completed. // handshake has completed.
@ -68,7 +74,8 @@ func (m *Manager) Listener() net.Listener {
ln := &listener{ ln := &listener{
m: m, m: m,
conf: &tls.Config{ conf: &tls.Config{
GetCertificate: m.GetCertificate, // bonus: panic on nil m GetCertificate: m.GetCertificate, // bonus: panic on nil m
NextProtos: []string{"h2", "http/1.1"}, // Enable HTTP/2
}, },
} }
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443") ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")

View File

@ -3,6 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
#define REDMASK51 0x0007FFFFFFFFFFFF #define REDMASK51 0x0007FFFFFFFFFFFF

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package curve25519 provides an implementation of scalar multiplication on // Package curve25519 provides an implementation of scalar multiplication on
// the elliptic curve known as curve25519. See http://cr.yp.to/ecdh.html // the elliptic curve known as curve25519. See https://cr.yp.to/ecdh.html
package curve25519 // import "golang.org/x/crypto/curve25519" package curve25519 // import "golang.org/x/crypto/curve25519"
// basePoint is the x coordinate of the generator of the curve. // basePoint is the x coordinate of the generator of the curve.

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

View File

@ -3,7 +3,7 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This code was translated into a form compatible with 6a from the public // This code was translated into a form compatible with 6a from the public
// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html // domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html
// +build amd64,!gccgo,!appengine // +build amd64,!gccgo,!appengine

Some files were not shown because too many files have changed in this diff Show More