pretest version of SSLSocket

This commit is contained in:
Moisés Guimarães
2016-12-09 15:15:51 -03:00
parent 567dfd76b3
commit b9934695fb
7 changed files with 167 additions and 179 deletions

View File

@ -20,6 +20,8 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring
METADATA = dict( METADATA = dict(
__name__="wolfssl", __name__="wolfssl",
__version__="0.1.0", __version__="0.1.0",

View File

@ -55,6 +55,7 @@ _VERIFY_MODE_LIST = [CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED]
_SSL_SUCCESS = 1 _SSL_SUCCESS = 1
_SSL_FILETYPE_PEM = 1 _SSL_FILETYPE_PEM = 1
_SSL_ERROR_WANT_READ = 2
class SSLContext(object): class SSLContext(object):
""" """
@ -112,8 +113,7 @@ class SSLContext(object):
def wrap_socket(self, sock, server_side=False, def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True, do_handshake_on_connect=True,
suppress_ragged_eofs=True, suppress_ragged_eofs=True):
server_hostname=None):
""" """
Wrap an existing Python socket sock and return an SSLSocket object. Wrap an existing Python socket sock and return an SSLSocket object.
sock must be a SOCK_STREAM socket; other socket types are unsupported. sock must be a SOCK_STREAM socket; other socket types are unsupported.
@ -126,7 +126,6 @@ class SSLContext(object):
return SSLSocket(sock=sock, server_side=server_side, return SSLSocket(sock=sock, server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect, do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs, suppress_ragged_eofs=suppress_ragged_eofs,
server_hostname=server_hostname,
_context=self) _context=self)
@ -209,7 +208,7 @@ class SSLContext(object):
class SSLSocket(socket): class SSLSocket(socket):
""" """
This class implements a subtype of socket.socket that wraps the This class implements a subtype of socket.socket that wraps the
underlying OS socket in an SSL context when necessary, and provides underlying OS socket in an SSL/TLS connection, providing secure
read and write methods over that channel. read and write methods over that channel.
""" """
@ -278,7 +277,7 @@ class SSLSocket(socket):
socket.__init__(self, family=family, sock_type=sock_type, socket.__init__(self, family=family, sock_type=sock_type,
proto=proto) proto=proto)
# See if we are connected # see if we are connected
try: try:
self.getpeername() self.getpeername()
except OSError as exception: except OSError as exception:
@ -289,28 +288,38 @@ class SSLSocket(socket):
connected = True connected = True
self._closed = False self._closed = False
self.native_object = _ffi.NULL
self._connected = connected self._connected = connected
# create the SSL object
self.native_object = _lib.wolfSSL_new(self.context.native_object)
if self.native_object == _ffi.NULL:
raise MemoryError("Unnable to allocate ssl object")
ret = _lib.wolfSSL_set_fd(self.native_object, self.fileno())
if ret != _SSL_SUCCESS:
self._release_native_object()
raise ValueError("Unnable to set fd to ssl object")
if connected: if connected:
# create the SSL object
try: try:
self.native_object = \
_lib.wolfSSL_new(self.context.native_object)
if self.native_object == _ffi.NULL:
raise MemoryError("Unnable to allocate ssl object")
ret = _lib.wolfSSL_set_fd(self.native_object, self.fileno)
if ret != _SSL_SUCCESS:
raise ValueError("Unnable to set fd to ssl object")
if do_handshake_on_connect: if do_handshake_on_connect:
self.do_handshake() self.do_handshake()
except (OSError, ValueError): except:
self.close() self._release_native_object()
self._socket.close()
raise raise
def __del__(self):
self._release_native_object()
def _release_native_object(self):
if self.native_object != _ffi.NULL:
_lib.wolfSSL_CTX_free(self.native_object)
self.native_object = _ffi.NULL
@property @property
def context(self): def context(self):
""" """
@ -324,6 +333,10 @@ class SSLSocket(socket):
self.__class__.__name__) self.__class__.__name__)
def _check_closed(self, call=None):
if self.native_object == _ffi.NULL:
raise ValueError("%s on closed or unwrapped secure channel" % call)
def _check_connected(self): def _check_connected(self):
if not self._connected: if not self._connected:
# getpeername() will raise ENOTCONN if the socket is really # getpeername() will raise ENOTCONN if the socket is really
@ -335,12 +348,11 @@ class SSLSocket(socket):
def write(self, data): def write(self, data):
""" """
Write DATA to the underlying SSL channel. Returns Write DATA to the underlying secure channel.
number of bytes of DATA actually transmitted. Returns number of bytes of DATA actually transmitted.
""" """
self._check_closed("write")
if self.native_object == _ffi.NULL: self._check_connected()
raise ValueError("Write on closed or unwrapped SSL socket")
data = t2b(data) data = t2b(data)
@ -348,134 +360,93 @@ class SSLSocket(socket):
def send(self, data, flags=0): def send(self, data, flags=0):
if self.native_object != _ffi.NULL: if flags != 0:
if flags != 0: raise NotImplementedError("non-zero flags not allowed in calls to "
raise ValueError( "send() on %s" % self.__class__)
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
return self.write(data)
else:
return socket.send(self, data, flags)
return self.write(data)
def sendto(self, data, flags_or_addr, addr=None):
if self.native_object != _ffi.NULL:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
elif addr is None:
return socket.sendto(self, data, flags_or_addr)
else:
return socket.sendto(self, data, flags_or_addr, addr)
def sendmsg(self, *args, **kwargs):
# Ensure programs don't send data unencrypted if they try to
# use this method.
raise NotImplementedError("sendmsg not allowed on instances of %s" %
self.__class__)
def sendall(self, data, flags=0): def sendall(self, data, flags=0):
if self.native_object != _ffi.NULL: if flags != 0:
if flags != 0: raise NotImplementedError("non-zero flags not allowed in calls to "
raise ValueError( "sendall() on %s" % self.__class__)
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
amount = len(data) length = len(data)
count = 0 sent = 0
while count < amount:
sent = self.send(data[count:]) while sent < length:
count += sent sent += self.write(data[sent:])
return amount
else: return sent
return socket.sendall(self, data, flags)
def sendto(self, data, flags_or_addr, addr=None):
# Ensure programs don't send unencrypted data trying to use this method
raise NotImplementedError("sendto not allowed on instances "
"of %s" % self.__class__)
def sendmsg(self, *args, **kwargs):
# Ensure programs don't send unencrypted data trying to use this method
raise NotImplementedError("sendmsg not allowed on instances "
"of %s" % self.__class__)
def sendfile(self, file, offset=0, count=None): def sendfile(self, file, offset=0, count=None):
""" # Ensure programs don't send unencrypted files trying to use this method
Send a file, possibly by using os.sendfile() if this is a raise NotImplementedError("sendfile not allowed on instances "
clear-text socket. Return the total number of bytes sent. "of %s" % self.__class__)
"""
# Ensure programs don't send unencrypted files if they try to
# use this method.
raise NotImplementedError("sendfile not allowed on instances of %s" %
self.__class__)
def read(self, length=1024, buffer=None): def read(self, length=1024, buffer=None):
""" """
Read up to LEN bytes and return them. Read up to LENGTH bytes and return them.
Return zero-length string on EOF. Return zero-length string on EOF.
""" """
self._check_closed("read")
self._check_connected()
if buffer is not None:
raise ValueError("buffer not allowed in calls to "
"read() on %s" % self.__class__)
if self.native_object == _ffi.NULL:
raise ValueError("Read on closed or unwrapped SSL socket")
data = t2b("\0" * length) data = t2b("\0" * length)
length = _lib.WolfSSL_read(self.native_object, data, length) length = _lib.WolfSSL_read(self.native_object, data, length)
if buffer is not None: if length < 0:
buffer.write(data, length) err = _lib.wolfSSL_get_error(self.native_object, 0)
return length if err == _SSL_ERROR_WANT_READ:
else: raise SSLWantReadError()
raise MemoryError("")
return self._sslobj.read(len, buffer)
except SSLError as exception:
if exception.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
return 0
else:
return b''
else: else:
raise raise SSLError("wolfSSL_read error (%d)" % err)
return data[:length] if length > 0 else b''
def recv(self, buflen=1024, flags=0): def recv(self, length=1024, flags=0):
self._checkClosed() if flags != 0:
if self._sslobj: raise NotImplementedError("non-zero flags not allowed in calls to "
if flags != 0: "recv() on %s" % self.__class__)
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" % return self.read(self, length)
self.__class__)
return self.read(buflen)
else:
return socket.recv(self, buflen, flags)
def recv_into(self, buffer, nbytes=None, flags=0): def recv_into(self, buffer, nbytes=None, flags=0):
self._checkClosed() raise NotImplementedError("recv_into not allowed on instances "
if buffer and (nbytes is None): "of %s" % self.__class__)
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s"
% self.__class__)
return self.read(nbytes, buffer)
else:
return socket.recv_into(self, buffer, nbytes, flags)
def recvfrom(self, buflen=1024, flags=0): def recvfrom(self, length=1024, flags=0):
self._checkClosed() # Ensure programs don't receive encrypted data trying to use this method
if self._sslobj: raise NotImplementedError("recvfrom not allowed on instances "
raise ValueError("recvfrom not allowed on instances of %s" % "of %s" % self.__class__)
self.__class__)
else:
return socket.recvfrom(self, buflen, flags)
def recvfrom_into(self, buffer, nbytes=None, flags=0): def recvfrom_into(self, buffer, nbytes=None, flags=0):
self._checkClosed() # Ensure programs don't receive encrypted data trying to use this method
if self._sslobj: raise NotImplementedError("recvfrom_into not allowed on instances "
raise ValueError("recvfrom_into not allowed on instances of %s" % "of %s" % self.__class__)
self.__class__)
else:
return socket.recvfrom_into(self, buffer, nbytes, flags)
def recvmsg(self, *args, **kwargs): def recvmsg(self, *args, **kwargs):
@ -489,84 +460,88 @@ class SSLSocket(socket):
def shutdown(self, how): def shutdown(self, how):
self._checkClosed() if self.native_object != _ffi.NULL:
self._sslobj = None _lib.wolfSSL_shutdown(self.native_object)
self._release_native_object()
socket.shutdown(self, how) socket.shutdown(self, how)
def unwrap(self): def unwrap(self):
if self._sslobj: """
s = self._sslobj.unwrap() Unwraps the underlying OS socket from the SSL/TLS connection.
self._sslobj = None Returns the wrapped OS socket.
return s """
else: if self.native_object != _ffi.NULL:
raise ValueError("No SSL wrapper around " + str(self)) _lib.wolfSSL_set_fd(self.native_object, -1)
sock = socket(family=self.family,
sock_type=self.type,
proto=self.proto,
fileno=self.fileno())
sock.settimeout(self.gettimeout())
self.detach()
return sock
def _real_close(self):
self._sslobj = None
socket._real_close(self)
def do_handshake(self, block=False): def do_handshake(self, block=False):
"""Perform a TLS/SSL handshake.""" """
Perform a TLS/SSL handshake.
"""
self._check_closed("do_handshake")
self._check_connected() self._check_connected()
timeout = self.gettimeout()
try: ret = _lib.wolfSSL_negotiate(self.native_object)
if timeout == 0.0 and block: if ret != _SSL_SUCCESS:
self.settimeout(None) raise SSLError("do_handshake failed with error %d" % ret)
self._sslobj.do_handshake()
finally:
self.settimeout(timeout)
def _real_connect(self, addr, connect_ex): def _real_connect(self, addr, connect_ex):
if self.server_side: if self.server_side:
raise ValueError("can't connect in server-side mode") raise ValueError("can't connect in server-side mode")
# Here we assume that the socket is client-side, and not # Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it. # connected at the time of the call. We connect it, then wrap it.
if self._connected: if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!") raise ValueError("attempt to connect already-connected SSLSocket!")
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self) if connect_ex:
try: err = self._socket.connect_ex(addr)
if connect_ex: else:
rc = socket.connect_ex(self, addr) err = 0
else: self._socket.connect(addr)
rc = None
socket.connect(self, addr) if err == 0:
if not rc: self._connected = True
self._connected = True if self.do_handshake_on_connect:
if self.do_handshake_on_connect: self.do_handshake()
self.do_handshake()
return rc return err
except (OSError, ValueError):
self._sslobj = None
raise
def connect(self, addr): def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in """
an SSL channel.""" Connects to remote ADDR, and then wraps the connection in a secure
channel.
"""
self._real_connect(addr, False) self._real_connect(addr, False)
def connect_ex(self, addr): def connect_ex(self, addr):
"""Connects to remote ADDR, and then wraps the connection in """
an SSL channel.""" Connects to remote ADDR, and then wraps the connection in a secure
channel.
"""
return self._real_connect(addr, True) return self._real_connect(addr, True)
def accept(self): def accept(self):
"""Accepts a new connection from a remote client, and returns """
a tuple containing that new connection wrapped with a server-side Accepts a new connection from a remote client, and returns a tuple
SSL channel, and the address of the remote client.""" containing that new connection wrapped with a server-side secure
channel, and the address of the remote client.
newsock, addr = socket.accept(self) """
newsock = self.context.wrap_socket( pass
newsock,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs,
server_side=True)
return newsock, addr
def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, def wrap_socket(sock, keyfile=None, certfile=None, server_side=False,

View File

@ -20,6 +20,8 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring
from socket import error as socket_error from socket import error as socket_error

View File

@ -19,9 +19,12 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring
try: try:
from wolfssl._ffi import ffi as _ffi from wolfssl._ffi import ffi as _ffi
from wolfssl._ffi import lib as _lib from wolfssl._ffi import lib as _lib
except ImportError: except ImportError:
pass pass

View File

@ -19,9 +19,12 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring, invalid-name
try: try:
from wolfssl._ffi import ffi as _ffi from wolfssl._ffi import ffi as _ffi
from wolfssl._ffi import lib as _lib from wolfssl._ffi import lib as _lib
except ImportError: except ImportError:
pass pass
@ -66,8 +69,8 @@ class WolfSSLMethod(object):
_lib.wolfTLSv1_2_client_method() _lib.wolfTLSv1_2_client_method()
elif protocol in [PROTOCOL_SSLv23, PROTOCOL_TLS]: elif protocol in [PROTOCOL_SSLv23, PROTOCOL_TLS]:
self.native_object = \ self.native_object = \
_lib.wolfSSLv23_server_method() if server_side else \ _lib.wolfSSLv23_server_method() if server_side else \
_lib.wolfSSLv23_client_method() _lib.wolfSSLv23_client_method()
if self.native_object == _ffi.NULL: if self.native_object == _ffi.NULL:

View File

@ -20,6 +20,8 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring, invalid-name
from cffi import FFI from cffi import FFI
ffi = FFI() ffi = FFI()
@ -63,6 +65,7 @@ ffi.cdef(
void wolfSSL_free(void*); void wolfSSL_free(void*);
int wolfSSL_set_fd(void*, int); int wolfSSL_set_fd(void*, int);
int wolfSSL_get_error(void*, int);
int wolfSSL_negotiate(void*); int wolfSSL_negotiate(void*);
int wolfSSL_write(void*, const void*, int); int wolfSSL_write(void*, const void*, int);
int wolfSSL_read(void*, void*, int); int wolfSSL_read(void*, void*, int);

View File

@ -20,7 +20,7 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=unused-import, undefined-variable # pylint: disable=missing-docstring, unused-import, undefined-variable
import sys import sys
from binascii import hexlify as b2h, unhexlify as h2b from binascii import hexlify as b2h, unhexlify as h2b