forked from wolfSSL/wolfssl
adds initial code for SSLSocket
This commit is contained in:
@ -154,7 +154,11 @@ class SSLContext(object):
|
|||||||
private key in.
|
private key in.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if certfile:
|
if password is not None:
|
||||||
|
raise NotImplementedError("password callback support not "
|
||||||
|
"implemented yet")
|
||||||
|
|
||||||
|
if certfile is not None:
|
||||||
ret = _lib.wolfSSL_CTX_use_certificate_chain_file(
|
ret = _lib.wolfSSL_CTX_use_certificate_chain_file(
|
||||||
self.native_object, t2b(certfile))
|
self.native_object, t2b(certfile))
|
||||||
if ret != _SSL_SUCCESS:
|
if ret != _SSL_SUCCESS:
|
||||||
@ -162,7 +166,7 @@ class SSLContext(object):
|
|||||||
else:
|
else:
|
||||||
raise TypeError("certfile should be a valid filesystem path")
|
raise TypeError("certfile should be a valid filesystem path")
|
||||||
|
|
||||||
if keyfile:
|
if keyfile is not None:
|
||||||
ret = _lib.wolfSSL_CTX_use_PrivateKey_file(
|
ret = _lib.wolfSSL_CTX_use_PrivateKey_file(
|
||||||
self.native_object, t2b(keyfile), _SSL_FILETYPE_PEM)
|
self.native_object, t2b(keyfile), _SSL_FILETYPE_PEM)
|
||||||
if ret != _SSL_SUCCESS:
|
if ret != _SSL_SUCCESS:
|
||||||
@ -185,7 +189,7 @@ class SSLContext(object):
|
|||||||
if cafile is None and capath is None and cadata is None:
|
if cafile is None and capath is None and cadata is None:
|
||||||
raise TypeError("cafile, capath and cadata cannot be all omitted")
|
raise TypeError("cafile, capath and cadata cannot be all omitted")
|
||||||
|
|
||||||
if cafile or capath:
|
if cafile is not None or capath is not None:
|
||||||
ret = _lib.wolfSSL_CTX_load_verify_locations(
|
ret = _lib.wolfSSL_CTX_load_verify_locations(
|
||||||
self.native_object,
|
self.native_object,
|
||||||
t2b(cafile) if cafile else _ffi.NULL,
|
t2b(cafile) if cafile else _ffi.NULL,
|
||||||
@ -194,13 +198,14 @@ class SSLContext(object):
|
|||||||
if ret != _SSL_SUCCESS:
|
if ret != _SSL_SUCCESS:
|
||||||
raise SSLError("Unnable to load verify locations. Err %d" % ret)
|
raise SSLError("Unnable to load verify locations. Err %d" % ret)
|
||||||
|
|
||||||
if cadata:
|
if cadata is not None:
|
||||||
ret = _lib.wolfSSL_CTX_load_verify_buffer(
|
ret = _lib.wolfSSL_CTX_load_verify_buffer(
|
||||||
self.native_object, t2b(cadata), len(cadata), _SSL_FILETYPE_PEM)
|
self.native_object, t2b(cadata), len(cadata), _SSL_FILETYPE_PEM)
|
||||||
|
|
||||||
if ret != _SSL_SUCCESS:
|
if ret != _SSL_SUCCESS:
|
||||||
raise SSLError("Unnable to load verify locations. Err %d" % ret)
|
raise SSLError("Unnable to load verify locations. Err %d" % ret)
|
||||||
|
|
||||||
|
|
||||||
class SSLSocket(socket):
|
class SSLSocket(socket):
|
||||||
"""
|
"""
|
||||||
This class implements a subtype of socket.socket that wraps the
|
This class implements a subtype of socket.socket that wraps the
|
||||||
@ -215,7 +220,353 @@ class SSLSocket(socket):
|
|||||||
sock_type=SOCK_STREAM, proto=0, fileno=None,
|
sock_type=SOCK_STREAM, proto=0, fileno=None,
|
||||||
suppress_ragged_eofs=True, ciphers=None,
|
suppress_ragged_eofs=True, ciphers=None,
|
||||||
_context=None):
|
_context=None):
|
||||||
pass
|
|
||||||
|
# set options
|
||||||
|
self.do_handshake_on_connect = do_handshake_on_connect
|
||||||
|
self.suppress_ragged_eofs = suppress_ragged_eofs
|
||||||
|
self.server_side = server_side
|
||||||
|
|
||||||
|
# set context
|
||||||
|
if _context:
|
||||||
|
self._context = _context
|
||||||
|
else:
|
||||||
|
if server_side and not certfile:
|
||||||
|
raise ValueError("certfile must be specified for server-side "
|
||||||
|
"operations")
|
||||||
|
|
||||||
|
if keyfile and not certfile:
|
||||||
|
raise ValueError("certfile must be specified")
|
||||||
|
|
||||||
|
if certfile and not keyfile:
|
||||||
|
keyfile = certfile
|
||||||
|
|
||||||
|
self._context = SSLContext(ssl_version, server_side)
|
||||||
|
self._context.verify_mode = cert_reqs
|
||||||
|
if ca_certs:
|
||||||
|
self._context.load_verify_locations(ca_certs)
|
||||||
|
if certfile:
|
||||||
|
self._context.load_cert_chain(certfile, keyfile)
|
||||||
|
if ciphers:
|
||||||
|
self._context.set_ciphers(ciphers)
|
||||||
|
|
||||||
|
self.keyfile = keyfile
|
||||||
|
self.certfile = certfile
|
||||||
|
self.cert_reqs = cert_reqs
|
||||||
|
self.ssl_version = ssl_version
|
||||||
|
self.ca_certs = ca_certs
|
||||||
|
self.ciphers = ciphers
|
||||||
|
|
||||||
|
# preparing socket
|
||||||
|
if sock is not None:
|
||||||
|
# Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
|
||||||
|
# mixed in.
|
||||||
|
if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
|
||||||
|
raise NotImplementedError("only stream sockets are supported")
|
||||||
|
|
||||||
|
socket.__init__(self,
|
||||||
|
family=sock.family,
|
||||||
|
sock_type=sock.type,
|
||||||
|
proto=sock.proto,
|
||||||
|
fileno=sock.fileno())
|
||||||
|
self.settimeout(sock.gettimeout())
|
||||||
|
sock.detach()
|
||||||
|
|
||||||
|
elif fileno is not None:
|
||||||
|
socket.__init__(self, fileno=fileno)
|
||||||
|
|
||||||
|
else:
|
||||||
|
socket.__init__(self, family=family, sock_type=sock_type,
|
||||||
|
proto=proto)
|
||||||
|
|
||||||
|
# See if we are connected
|
||||||
|
try:
|
||||||
|
self.getpeername()
|
||||||
|
except OSError as exception:
|
||||||
|
if exception.errno != errno.ENOTCONN:
|
||||||
|
raise
|
||||||
|
connected = False
|
||||||
|
else:
|
||||||
|
connected = True
|
||||||
|
|
||||||
|
self._closed = False
|
||||||
|
self.native_object = _ffi.NULL
|
||||||
|
self._connected = connected
|
||||||
|
|
||||||
|
if connected:
|
||||||
|
# create the SSL object
|
||||||
|
try:
|
||||||
|
self.native_object = \
|
||||||
|
_lib.wolfSSL_new(self.context.native_object)
|
||||||
|
if self.native_object == _ffi.NULL:
|
||||||
|
raise MemoryError("Unnable to allocate ssl object")
|
||||||
|
|
||||||
|
ret = _lib.wolfSSL_set_fd(self.native_object, self.fileno)
|
||||||
|
if ret != _SSL_SUCCESS:
|
||||||
|
raise ValueError("Unnable to set fd to ssl object")
|
||||||
|
|
||||||
|
if do_handshake_on_connect:
|
||||||
|
self.do_handshake()
|
||||||
|
except (OSError, ValueError):
|
||||||
|
self.close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context(self):
|
||||||
|
"""
|
||||||
|
Returns the context used by this object.
|
||||||
|
"""
|
||||||
|
return self._context
|
||||||
|
|
||||||
|
|
||||||
|
def dup(self):
|
||||||
|
raise NotImplementedError("Can't dup() %s instances" %
|
||||||
|
self.__class__.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_connected(self):
|
||||||
|
if not self._connected:
|
||||||
|
# getpeername() will raise ENOTCONN if the socket is really
|
||||||
|
# not connected; note that we can be connected even without
|
||||||
|
# _connected being set, e.g. if connect() first returned
|
||||||
|
# EAGAIN.
|
||||||
|
self.getpeername()
|
||||||
|
|
||||||
|
|
||||||
|
def write(self, data):
|
||||||
|
"""
|
||||||
|
Write DATA to the underlying SSL channel. Returns
|
||||||
|
number of bytes of DATA actually transmitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.native_object == _ffi.NULL:
|
||||||
|
raise ValueError("Write on closed or unwrapped SSL socket")
|
||||||
|
|
||||||
|
data = t2b(data)
|
||||||
|
|
||||||
|
return _lib.wolfSSL_write(self.native_object, data, len(data))
|
||||||
|
|
||||||
|
|
||||||
|
def send(self, data, flags=0):
|
||||||
|
if self.native_object != _ffi.NULL:
|
||||||
|
if flags != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"non-zero flags not allowed in calls to send() on %s" %
|
||||||
|
self.__class__)
|
||||||
|
return self.write(data)
|
||||||
|
else:
|
||||||
|
return socket.send(self, data, flags)
|
||||||
|
|
||||||
|
|
||||||
|
def sendto(self, data, flags_or_addr, addr=None):
|
||||||
|
if self.native_object != _ffi.NULL:
|
||||||
|
raise ValueError("sendto not allowed on instances of %s" %
|
||||||
|
self.__class__)
|
||||||
|
elif addr is None:
|
||||||
|
return socket.sendto(self, data, flags_or_addr)
|
||||||
|
else:
|
||||||
|
return socket.sendto(self, data, flags_or_addr, addr)
|
||||||
|
|
||||||
|
|
||||||
|
def sendmsg(self, *args, **kwargs):
|
||||||
|
# Ensure programs don't send data unencrypted if they try to
|
||||||
|
# use this method.
|
||||||
|
raise NotImplementedError("sendmsg not allowed on instances of %s" %
|
||||||
|
self.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
def sendall(self, data, flags=0):
|
||||||
|
if self.native_object != _ffi.NULL:
|
||||||
|
if flags != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"non-zero flags not allowed in calls to sendall() on %s" %
|
||||||
|
self.__class__)
|
||||||
|
|
||||||
|
amount = len(data)
|
||||||
|
count = 0
|
||||||
|
while count < amount:
|
||||||
|
sent = self.send(data[count:])
|
||||||
|
count += sent
|
||||||
|
return amount
|
||||||
|
else:
|
||||||
|
return socket.sendall(self, data, flags)
|
||||||
|
|
||||||
|
|
||||||
|
def sendfile(self, file, offset=0, count=None):
|
||||||
|
"""
|
||||||
|
Send a file, possibly by using os.sendfile() if this is a
|
||||||
|
clear-text socket. Return the total number of bytes sent.
|
||||||
|
"""
|
||||||
|
# Ensure programs don't send unencrypted files if they try to
|
||||||
|
# use this method.
|
||||||
|
raise NotImplementedError("sendfile not allowed on instances of %s" %
|
||||||
|
self.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
def read(self, length=1024, buffer=None):
|
||||||
|
"""
|
||||||
|
Read up to LEN bytes and return them.
|
||||||
|
Return zero-length string on EOF.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.native_object == _ffi.NULL:
|
||||||
|
raise ValueError("Read on closed or unwrapped SSL socket")
|
||||||
|
|
||||||
|
data = t2b("\0" * length)
|
||||||
|
length = _lib.WolfSSL_read(self.native_object, data, length)
|
||||||
|
|
||||||
|
if buffer is not None:
|
||||||
|
buffer.write(data, length)
|
||||||
|
return length
|
||||||
|
else:
|
||||||
|
raise MemoryError("")
|
||||||
|
|
||||||
|
return self._sslobj.read(len, buffer)
|
||||||
|
except SSLError as exception:
|
||||||
|
if exception.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||||
|
if buffer is not None:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return b''
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def recv(self, buflen=1024, flags=0):
|
||||||
|
self._checkClosed()
|
||||||
|
if self._sslobj:
|
||||||
|
if flags != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"non-zero flags not allowed in calls to recv() on %s" %
|
||||||
|
self.__class__)
|
||||||
|
return self.read(buflen)
|
||||||
|
else:
|
||||||
|
return socket.recv(self, buflen, flags)
|
||||||
|
|
||||||
|
|
||||||
|
def recv_into(self, buffer, nbytes=None, flags=0):
|
||||||
|
self._checkClosed()
|
||||||
|
if buffer and (nbytes is None):
|
||||||
|
nbytes = len(buffer)
|
||||||
|
elif nbytes is None:
|
||||||
|
nbytes = 1024
|
||||||
|
if self._sslobj:
|
||||||
|
if flags != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"non-zero flags not allowed in calls to recv_into() on %s"
|
||||||
|
% self.__class__)
|
||||||
|
return self.read(nbytes, buffer)
|
||||||
|
else:
|
||||||
|
return socket.recv_into(self, buffer, nbytes, flags)
|
||||||
|
|
||||||
|
|
||||||
|
def recvfrom(self, buflen=1024, flags=0):
|
||||||
|
self._checkClosed()
|
||||||
|
if self._sslobj:
|
||||||
|
raise ValueError("recvfrom not allowed on instances of %s" %
|
||||||
|
self.__class__)
|
||||||
|
else:
|
||||||
|
return socket.recvfrom(self, buflen, flags)
|
||||||
|
|
||||||
|
|
||||||
|
def recvfrom_into(self, buffer, nbytes=None, flags=0):
|
||||||
|
self._checkClosed()
|
||||||
|
if self._sslobj:
|
||||||
|
raise ValueError("recvfrom_into not allowed on instances of %s" %
|
||||||
|
self.__class__)
|
||||||
|
else:
|
||||||
|
return socket.recvfrom_into(self, buffer, nbytes, flags)
|
||||||
|
|
||||||
|
|
||||||
|
def recvmsg(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("recvmsg not allowed on instances of %s" %
|
||||||
|
self.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
def recvmsg_into(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError("recvmsg_into not allowed on instances of "
|
||||||
|
"%s" % self.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown(self, how):
|
||||||
|
self._checkClosed()
|
||||||
|
self._sslobj = None
|
||||||
|
socket.shutdown(self, how)
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap(self):
|
||||||
|
if self._sslobj:
|
||||||
|
s = self._sslobj.unwrap()
|
||||||
|
self._sslobj = None
|
||||||
|
return s
|
||||||
|
else:
|
||||||
|
raise ValueError("No SSL wrapper around " + str(self))
|
||||||
|
|
||||||
|
def _real_close(self):
|
||||||
|
self._sslobj = None
|
||||||
|
socket._real_close(self)
|
||||||
|
|
||||||
|
def do_handshake(self, block=False):
|
||||||
|
"""Perform a TLS/SSL handshake."""
|
||||||
|
self._check_connected()
|
||||||
|
timeout = self.gettimeout()
|
||||||
|
try:
|
||||||
|
if timeout == 0.0 and block:
|
||||||
|
self.settimeout(None)
|
||||||
|
self._sslobj.do_handshake()
|
||||||
|
finally:
|
||||||
|
self.settimeout(timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def _real_connect(self, addr, connect_ex):
|
||||||
|
if self.server_side:
|
||||||
|
raise ValueError("can't connect in server-side mode")
|
||||||
|
# Here we assume that the socket is client-side, and not
|
||||||
|
# connected at the time of the call. We connect it, then wrap it.
|
||||||
|
if self._connected:
|
||||||
|
raise ValueError("attempt to connect already-connected SSLSocket!")
|
||||||
|
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
|
||||||
|
self._sslobj = SSLObject(sslobj, owner=self)
|
||||||
|
try:
|
||||||
|
if connect_ex:
|
||||||
|
rc = socket.connect_ex(self, addr)
|
||||||
|
else:
|
||||||
|
rc = None
|
||||||
|
socket.connect(self, addr)
|
||||||
|
if not rc:
|
||||||
|
self._connected = True
|
||||||
|
if self.do_handshake_on_connect:
|
||||||
|
self.do_handshake()
|
||||||
|
return rc
|
||||||
|
except (OSError, ValueError):
|
||||||
|
self._sslobj = None
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def connect(self, addr):
|
||||||
|
"""Connects to remote ADDR, and then wraps the connection in
|
||||||
|
an SSL channel."""
|
||||||
|
self._real_connect(addr, False)
|
||||||
|
|
||||||
|
|
||||||
|
def connect_ex(self, addr):
|
||||||
|
"""Connects to remote ADDR, and then wraps the connection in
|
||||||
|
an SSL channel."""
|
||||||
|
return self._real_connect(addr, True)
|
||||||
|
|
||||||
|
|
||||||
|
def accept(self):
|
||||||
|
"""Accepts a new connection from a remote client, and returns
|
||||||
|
a tuple containing that new connection wrapped with a server-side
|
||||||
|
SSL channel, and the address of the remote client."""
|
||||||
|
|
||||||
|
newsock, addr = socket.accept(self)
|
||||||
|
newsock = self.context.wrap_socket(
|
||||||
|
newsock,
|
||||||
|
do_handshake_on_connect=self.do_handshake_on_connect,
|
||||||
|
suppress_ragged_eofs=self.suppress_ragged_eofs,
|
||||||
|
server_side=True)
|
||||||
|
return newsock, addr
|
||||||
|
|
||||||
|
|
||||||
def wrap_socket(sock, keyfile=None, certfile=None, server_side=False,
|
def wrap_socket(sock, keyfile=None, certfile=None, server_side=False,
|
||||||
|
Reference in New Issue
Block a user