Handle control frames that arrive in multiple TCP fragments

This commit is contained in:
Kovid Goyal 2015-10-26 15:36:06 +05:30
parent c9adbebe7f
commit 1555725117

View File

@ -63,6 +63,7 @@ class ReadFrame(object): # {{{
def reset(self): def reset(self):
self.state = self.read_header0 self.state = self.read_header0
self.control_buf = []
def __call__(self, conn): def __call__(self, conn):
return self.state(conn) return self.state(conn)
@ -80,11 +81,12 @@ class ReadFrame(object): # {{{
self.opcode = b & 0b1111 self.opcode = b & 0b1111
self.state = self.read_header1 self.state = self.read_header1
self.is_control = self.opcode in CONTROL_CODES
if self.opcode not in ALL_CODES: if self.opcode not in ALL_CODES:
conn.log.error('Unknown OPCODE from client: %r' % self.opcode) conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode)
return return
if not self.fin and self.opcode in CONTROL_CODES: if not self.fin and self.is_control:
conn.log.error('Fragmented control frame from client') conn.log.error('Fragmented control frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame') conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame')
return return
@ -101,7 +103,7 @@ class ReadFrame(object): # {{{
self.reset() self.reset()
return return
self.payload_length = b & 0b01111111 self.payload_length = b & 0b01111111
if self.opcode in CONTROL_CODES and self.payload_length > 125: if self.is_control and self.payload_length > 125:
conn.log.error('Too large control frame from client') conn.log.error('Too large control frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large') conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large')
self.reset() self.reset()
@ -155,6 +157,12 @@ class ReadFrame(object): # {{{
data = b'' data = b''
self.pos += len(data) self.pos += len(data)
frame_finished = self.pos >= self.payload_length frame_finished = self.pos >= self.payload_length
if self.is_control:
self.control_buf.append(data)
if frame_finished:
data = b''.join(self.control_buf)
else:
return
conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin) conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin)
self.frame_starting = False self.frame_starting = False
if frame_finished: if frame_finished: