mirror of
				https://github.com/espressif/esp-idf.git
				synced 2025-11-04 09:01:40 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			127 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			127 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import print_function, unicode_literals
 | 
						|
 | 
						|
import os
 | 
						|
import re
 | 
						|
import socket
 | 
						|
import ssl
 | 
						|
from threading import Event, Thread
 | 
						|
 | 
						|
import ttfw_idf
 | 
						|
 | 
						|
# Directory, relative to this script, holding the local TLS test server's
# ca.crt / ca.key (joined onto the script directory by _path in TlsServer.run_server).
SERVER_CERTS_DIR = 'server_certs/'
 | 
						|
 | 
						|
 | 
						|
def _path(f):
 | 
						|
    return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)
 | 
						|
 | 
						|
 | 
						|
def get_my_ip():
    """Return the IPv4 address this host would use for outbound traffic.

    A UDP socket is "connected" to an arbitrary private address. This makes the
    OS choose a local address for that route, which ``getsockname`` then reports.
    If either step fails with a socket error, '127.0.0.1' is returned instead.
    The probe socket is always closed.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # The destination doesn't even have to be reachable.
        probe.connect(('10.255.255.255', 1))
        return probe.getsockname()[0]
    except socket.error:
        return '127.0.0.1'
    finally:
        probe.close()
 | 
						|
 | 
						|
 | 
						|
# Simple TLS server
 | 
						|
class TlsServer:
    """Minimal single-connection TLS server used as a peer for the DUT.

    Use it as a context manager. ``__enter__`` binds to ``port`` on all
    interfaces, listens, and starts a background thread that accepts exactly
    one TLS connection. ``__exit__`` waits for that thread and closes the sockets.

    Attributes:
        port: TCP port to listen on.
        socket: the listening socket; replaced by its TLS-wrapped version once
            the server thread starts.
        shutdown: Event set by ``__exit__``. NOTE(review): ``run_server`` does
            not poll it, so ``__exit__`` can block until ``accept`` times out
            (20 s timeout set in ``__init__``).
        negotiated_protocol: ``ssl.PROTOCOL_*`` value used to build the server's
            ``SSLContext``.
        conn: the accepted TLS connection, or ``None``.
        ssl_error: text of the ``ssl.SSLError`` raised during accept/handshake,
            or ``None``.
        server_thread: background thread running ``run_server``.
    """

    def __init__(self, port, negotiated_protocol=ssl.PROTOCOL_TLSv1):
        self.port = port
        self.socket = socket.socket()
        # Allow quick re-binding of the same port across consecutive test runs.
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Bounds how long accept() waits for the DUT.
        self.socket.settimeout(20.0)
        self.shutdown = Event()
        self.negotiated_protocol = negotiated_protocol
        self.conn = None
        self.ssl_error = None
        self.server_thread = None

    def __enter__(self):
        """Bind and listen on ``self.port``, then start the accept thread.

        Returns:
            self
        Raises:
            socket.error: if binding fails. The socket is closed before the
                error is re-raised.
        """
        try:
            self.socket.bind(('', self.port))
        except socket.error as e:
            print('Bind failed:{}'.format(e))
            # __exit__ is not called when __enter__ raises, so release the socket here.
            self.socket.close()
            raise

        self.socket.listen(1)
        self.server_thread = Thread(target=self.run_server)
        self.server_thread.start()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Signal shutdown, wait for the server thread and close both sockets."""
        self.shutdown.set()
        self.server_thread.join()
        if self.conn is not None:
            self.conn.close()
        self.socket.close()

    def run_server(self):
        """Accept a single TLS connection and record the outcome.

        Loads ``ca.crt`` / ``ca.key`` from ``SERVER_CERTS_DIR`` next to this
        script and wraps the listening socket for server-side TLS. It then
        blocks in ``accept()``, which also performs the TLS handshake.
        On success the connection is stored in ``self.conn``. On an
        ``ssl.SSLError`` its text is stored in ``self.ssl_error``. If no client
        arrives before the listening socket's timeout, the outcome is only
        printed.
        """
        ctx = ssl.SSLContext(self.negotiated_protocol)
        ctx.load_cert_chain(certfile=_path(SERVER_CERTS_DIR + 'ca.crt'), keyfile=_path(SERVER_CERTS_DIR + 'ca.key'))
        self.socket = ctx.wrap_socket(self.socket, server_side=True)
        try:
            print('Listening socket')
            self.conn, address = self.socket.accept()  # accept new connection (performs the TLS handshake)
            # Bound I/O on the accepted connection. The listening socket
            # already carries its timeout from __init__.
            self.conn.settimeout(20.0)
            print(' - connection from: {}'.format(address))
        except ssl.SSLError as e:
            self.conn = None
            self.ssl_error = str(e)
            print(' - SSLError: {}'.format(str(e)))
        except socket.timeout:
            # No client connected in time. Report it instead of letting the
            # thread die with an unhandled traceback.
            self.conn = None
            print(' - accept timed out, no connection received')
 | 
						|
 | 
						|
 | 
						|
def test_echo(dut):
    """Block until the DUT reports a successful TLS connection, then log it."""
    success_msg = 'SSL Connection Succeed'
    dut.expect(success_msg)
    print(success_msg)
 | 
						|
 | 
						|
 | 
						|
@ttfw_idf.idf_example_test(env_tag='Example_WIFI')
def test_example_protocol_openssl_client(env, extra_data):
    """Run the openssl_client example on an ESP32 DUT.

    Steps:
      1. Start the app and wait for the DUT to report an IPv4 address.
      2. Connect to a target "xxxx.xxxx.xxxx.xxxx:port". Which target depends
         on the sdkconfig:
         - With CONFIG_EXAMPLE_OPENSSL_CLIENT_URI_FROM_STDIN set, a local
           TlsServer is started on port 2222 and "<host ip> <port>" is written
           to the DUT.
         - Otherwise the DUT uses CONFIG_EXAMPLE_OPENSSL_CLIENT_TARGET_DOMAIN.
      3. Expect the DUT to print 'SSL Connection Succeed'.
    """
    dut1 = env.get_dut('openssl_client', 'examples/protocols/openssl_client', dut_class=ttfw_idf.ESP32DUT)

    # Record the application binary size (in KiB) as a performance metric.
    bin_path = os.path.join(dut1.app.binary_path, 'openssl_client.bin')
    size_kb = os.path.getsize(bin_path) // 1024
    ttfw_idf.log_performance('openssl_client_bin_size', '{}KB'.format(size_kb))

    try:
        uri_from_stdin = 'CONFIG_EXAMPLE_OPENSSL_CLIENT_URI_FROM_STDIN' in dut1.app.get_sdkconfig()
        if not uri_from_stdin:
            uri = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_OPENSSL_CLIENT_TARGET_DOMAIN'].strip('"')
    except Exception:
        print('ENV_TEST_FAILURE: Cannot find target domain in sdkconfig')
        raise

    # start test
    dut1.start_app()
    dut1.expect(re.compile(r' IPv4 address: ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)'), timeout=30)
    ip = get_my_ip()

    if not uri_from_stdin:
        print('DUT connecting to {}'.format(uri))
        test_echo(dut1)
        return

    server_port = 2222
    with TlsServer(server_port, negotiated_protocol=ssl.PROTOCOL_TLSv1_1):
        print('Starting test')
        # Tell the DUT where to connect: "<host ip> <port>".
        dut1.write('{} {}'.format(ip, server_port))
        dut1.expect(re.compile('SSL Connection Succeed'), timeout=10)
 | 
						|
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Entry point when this file is run directly as a script.
    # NOTE(review): env / extra_data are not passed here; presumably the
    # ttfw_idf.idf_example_test decorator supplies them. Confirm.
    test_example_protocol_openssl_client()
 |