mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-08 10:44:09 -04:00
Return protocol error when RSV is set
This commit is contained in:
parent
f4b5d451fd
commit
b6eceba62b
@ -130,8 +130,8 @@ class WSClient(object):
|
||||
break
|
||||
self.socket.sendall(frame.getvalue())
|
||||
|
||||
def write_frame(self, fin, opcode, payload=b'', rsv=0, mask=True):
|
||||
frame = create_frame(fin, opcode, payload, rsv=rsv, mask=self.mask if mask else None)
|
||||
def write_frame(self, fin=1, opcode=CLOSE, payload=b'', rsv=0, mask=True):
|
||||
frame = create_frame(fin, opcode, payload, rsv=(rsv << 4), mask=self.mask if mask else None)
|
||||
self.socket.sendall(frame)
|
||||
|
||||
def write_close(self, code, reason=b''):
|
||||
@ -196,7 +196,10 @@ class WebSocketTest(BaseTest):
|
||||
|
||||
def simple_test(self, client, msgs, expected=(), close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
|
||||
for msg in msgs:
|
||||
client.write_message(msg)
|
||||
if isinstance(msg, dict):
|
||||
client.write_frame(**msg)
|
||||
else:
|
||||
client.write_message(msg)
|
||||
for ex in expected:
|
||||
if isinstance(ex, type('')):
|
||||
ex = TEXT, ex
|
||||
@ -233,3 +236,11 @@ class WebSocketTest(BaseTest):
|
||||
for payload in (b'', b'pong'):
|
||||
client = server.connect()
|
||||
self.simple_test(client, [(PONG, payload)], [])
|
||||
|
||||
with server.silence_log:
|
||||
for rsv in xrange(1, 7):
|
||||
client = server.connect()
|
||||
self.simple_test(client, [{'rsv':rsv, 'opcode':BINARY}], [], close_code=PROTOCOL_ERROR, send_close=False)
|
||||
for opcode in (3, 4, 5, 6, 7, 11, 12, 13, 14, 15):
|
||||
client = server.connect()
|
||||
self.simple_test(client, [{'opcode':opcode}], [], close_code=PROTOCOL_ERROR, send_close=False)
|
||||
|
@ -64,6 +64,11 @@ class ReadFrame(object): # {{{
|
||||
return
|
||||
b = ord(data)
|
||||
self.fin = b & 0b10000000
|
||||
if b & 0b01110000:
|
||||
conn.log.error('RSV bits set in frame from client')
|
||||
conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set')
|
||||
return
|
||||
|
||||
self.opcode = b & 0b1111
|
||||
self.state = self.read_header1
|
||||
if self.opcode not in ALL_CODES:
|
||||
@ -153,8 +158,8 @@ def create_frame(fin, opcode, payload, mask=None, rsv=0):
|
||||
payload = payload.encode('utf-8')
|
||||
l = len(payload)
|
||||
opcode &= 0b1111
|
||||
b1 = opcode | (0b10000000 if fin else 0)
|
||||
b2 = rsv | (0 if mask is None else 0b10000000)
|
||||
b1 = opcode | (0b10000000 if fin else 0) | (rsv & 0b01110000)
|
||||
b2 = 0 if mask is None else 0b10000000
|
||||
if l < 126:
|
||||
header = bytes(bytearray((b1, b2 | l)))
|
||||
elif 126 <= l <= 65535:
|
||||
|
Loading…
x
Reference in New Issue
Block a user