Change the control frames queue to FIFO

Also properly handle fragmented control frames
This commit is contained in:
Kovid Goyal 2015-10-25 21:42:23 +05:30
parent f908df4991
commit 51f0e93be5
2 changed files with 21 additions and 6 deletions

View File

@ -244,3 +244,10 @@ class WebSocketTest(BaseTest):
for opcode in (3, 4, 5, 6, 7, 11, 12, 13, 14, 15): for opcode in (3, 4, 5, 6, 7, 11, 12, 13, 14, 15):
client = server.connect() client = server.connect()
self.simple_test(client, [{'opcode':opcode}], [], close_code=PROTOCOL_ERROR, send_close=False) self.simple_test(client, [{'opcode':opcode}], [], close_code=PROTOCOL_ERROR, send_close=False)
for opcode in (PING, PONG):
client = server.connect()
self.simple_test(client, [
{'opcode':opcode, 'payload':'f1', 'fin':0}, {'opcode':opcode, 'payload':'f2'}
], close_code=PROTOCOL_ERROR, send_close=False)

View File

@ -7,6 +7,7 @@ from __future__ import (unicode_literals, division, absolute_import,
import codecs, httplib, struct, os, weakref, repr as reprlib, time, socket import codecs, httplib, struct, os, weakref, repr as reprlib, time, socket
from base64 import standard_b64encode from base64 import standard_b64encode
from collections import deque
from functools import partial from functools import partial
from hashlib import sha1 from hashlib import sha1
from io import BytesIO from io import BytesIO
@ -63,7 +64,7 @@ class ReadFrame(object): # {{{
if not data: if not data:
return return
b = ord(data) b = ord(data)
self.fin = b & 0b10000000 self.fin = bool(b & 0b10000000)
if b & 0b01110000: if b & 0b01110000:
conn.log.error('RSV bits set in frame from client') conn.log.error('RSV bits set in frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set') conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set')
@ -75,6 +76,10 @@ class ReadFrame(object): # {{{
conn.log.error('Unknown OPCODE from client: %r' % self.opcode) conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode)
return return
if not self.fin and self.opcode in CONTROL_CODES:
conn.log.error('Fragmented control frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame')
return
def read_header1(self, conn): def read_header1(self, conn):
data = conn.recv(1) data = conn.recv(1)
@ -85,11 +90,13 @@ class ReadFrame(object): # {{{
if not self.mask: if not self.mask:
conn.log.error('Unmasked packet from client') conn.log.error('Unmasked packet from client')
conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed')
self.reset()
return return
self.payload_length = b & 0b01111111 self.payload_length = b & 0b01111111
if self.opcode in (PING, PONG) and self.payload_length > 125: if self.opcode in (PING, PONG) and self.payload_length > 125:
conn.log.error('Too large ping packet from client') conn.log.error('Too large ping packet from client')
conn.websocket_close(PROTOCOL_ERROR, 'Ping packet too large') conn.websocket_close(PROTOCOL_ERROR, 'Ping packet too large')
self.reset()
return return
self.mask_buf = b'' self.mask_buf = b''
if self.payload_length == 126: if self.payload_length == 126:
@ -128,7 +135,7 @@ class ReadFrame(object): # {{{
self.pos = 0 self.pos = 0
self.frame_starting = True self.frame_starting = True
if self.payload_length == 0: if self.payload_length == 0:
conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin)) conn.ws_data_received(b'', self.opcode, True, True, self.fin)
self.reset() self.reset()
def read_payload(self, conn): def read_payload(self, conn):
@ -145,7 +152,7 @@ class ReadFrame(object): # {{{
data = bytes(data) data = bytes(data)
self.pos += len(data) self.pos += len(data)
frame_finished = self.pos >= self.payload_length frame_finished = self.pos >= self.payload_length
conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, bool(self.fin)) conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin)
self.frame_starting = False self.frame_starting = False
if frame_finished: if frame_finished:
self.reset() self.reset()
@ -219,7 +226,7 @@ class WebSocketConnection(HTTPConnection):
global conn_id global conn_id
HTTPConnection.__init__(self, *args, **kwargs) HTTPConnection.__init__(self, *args, **kwargs)
self.sendq = Queue() self.sendq = Queue()
self.control_frames = [] self.control_frames = deque()
self.cf_lock = Lock() self.cf_lock = Lock()
self.sending = None self.sending = None
self.send_buf = None self.send_buf = None
@ -377,7 +384,7 @@ class WebSocketConnection(HTTPConnection):
else: else:
with self.cf_lock: with self.cf_lock:
try: try:
self.send_buf = self.control_frames.pop() self.send_buf = self.control_frames.popleft()
except IndexError: except IndexError:
if self.sending is not None: if self.sending is not None:
self.send_buf = self.sending.create_frame() self.send_buf = self.sending.create_frame()
@ -396,6 +403,7 @@ class WebSocketConnection(HTTPConnection):
try: try:
if self.send_buf is None: if self.send_buf is None:
self.websocket_close(SHUTTING_DOWN, 'Shutting down') self.websocket_close(SHUTTING_DOWN, 'Shutting down')
with self.cf_lock:
self.write(self.control_frames.pop()) self.write(self.control_frames.pop())
except Exception: except Exception:
pass pass