tcp_transport/ws_client: websockets now correctly handle messages longer than buffer

transport_ws can now be read multiple times in a row to read frames larger than the buffer.

Added reporting of total payload length and offset to the user in websocket_client.

Added local example test for long messages.

Closes IDF-1083
This commit is contained in:
Marius Vikhammer
2020-03-20 11:07:07 +08:00
committed by bot
parent d6ef9d73bb
commit b56012783c
7 changed files with 305 additions and 109 deletions

View File

@@ -2,11 +2,15 @@ from __future__ import print_function
from __future__ import unicode_literals
import re
import os
import sys
import socket
import select
import hashlib
import base64
import sys
from threading import Thread
import queue
import random
import string
from threading import Thread, Event
try:
import IDF
@@ -20,8 +24,6 @@ except Exception:
sys.path.insert(0, test_fw_path)
import IDF
import DUT
def get_my_ip():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -43,7 +45,10 @@ class Websocket:
def __init__(self, port):
self.port = port
self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.settimeout(10.0)
self.send_q = queue.Queue()
self.shutdown = Event()
def __enter__(self):
try:
@@ -56,23 +61,27 @@ class Websocket:
self.server_thread = Thread(target=self.run_server)
self.server_thread.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.shutdown.set()
self.server_thread.join()
self.socket.close()
self.conn.close()
def run_server(self):
self.conn, address = self.socket.accept() # accept new connection
self.conn.settimeout(10.0)
self.socket.settimeout(10.0)
print("Connection from: {}".format(address))
self.establish_connection()
# Echo data until client closes connection
self.echo_data()
print("WS established")
# Handle connection until client closes it, will echo any data received and send data from send_q queue
self.handle_conn()
def establish_connection(self):
while True:
while not self.shutdown.is_set():
try:
# receive data stream. it won't accept data packet greater than 1024 bytes
data = self.conn.recv(1024).decode()
@@ -83,6 +92,7 @@ class Websocket:
if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
self.handshake(data)
return
except socket.error as err:
print("Unable to establish a websocket connection: {}, {}".format(err))
raise
@@ -107,26 +117,46 @@ class Websocket:
self.conn.send(resp.encode())
def echo_data(self):
while(True):
def handle_conn(self):
while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try:
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if data is not received
return
if self.conn in r:
self.echo_data()
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
if not self.send_q.empty():
self._send_data_(self.send_q.get())
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
frame = header + payload
decoded_payload = self.decode_frame(frame)
echo_frame = self.encode_frame(decoded_payload)
self.conn.send(echo_frame)
except socket.error as err:
print("Stopped echoing data: {}".format(err))
raise
def echo_data(self):
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if socket closed by peer
return
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
if not payload:
# exit if socket closed by peer
return
frame = header + payload
decoded_payload = self.decode_frame(frame)
print("Sending echo...")
self._send_data_(decoded_payload)
def _send_data_(self, data):
frame = self.encode_frame(data)
self.conn.send(frame)
def send_data(self, data):
self.send_q.put(data.encode())
def decode_frame(self, frame):
# Mask out MASK bit from payload length, this len is only valid for short messages (<126)
@@ -147,7 +177,17 @@ class Websocket:
header = (1 << 7) | (1 << 0)
frame = bytearray([header])
frame.append(len(payload))
payload_len = len(payload)
# If payload len is longer than 125 then the next 16 bits are used to encode length
if payload_len > 125:
frame.append(126)
frame.append(payload_len >> 8)
frame.append(0xFF & payload_len)
else:
frame.append(payload_len)
frame += payload
return frame
@@ -156,8 +196,27 @@ class Websocket:
def test_echo(dut):
dut.expect("WEBSOCKET_EVENT_CONNECTED")
for i in range(0, 10):
dut.expect(re.compile(r"Received=hello (\d)"))
dut.expect("Websocket Stopped")
dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
print("All echos received")
def test_recv_long_msg(dut, websocket, msg_len, repeats):
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
for _ in range(repeats):
websocket.send_data(send_msg)
recv_msg = ''
while len(recv_msg) < msg_len:
# Filter out color encoding
match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
recv_msg += match
if recv_msg == send_msg:
print("Sent message and received message are equal")
else:
raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
\nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
@IDF.idf_example_test(env_tag="Example_WIFI")
@@ -191,12 +250,14 @@ def test_examples_protocol_websocket(env, extra_data):
if uri_from_stdin:
server_port = 4455
with Websocket(server_port):
with Websocket(server_port) as ws:
uri = "ws://{}:{}".format(get_my_ip(), server_port)
print("DUT connecting to {}".format(uri))
dut1.expect("Please enter uri of websocket endpoint", timeout=30)
dut1.write(uri)
test_echo(dut1)
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
test_recv_long_msg(dut1, ws, 2000, 3)
else:
print("DUT connecting to {}".format(uri))