pretest version of SSLSocket

This commit is contained in:
Moisés Guimarães
2016-12-09 15:15:51 -03:00
parent 567dfd76b3
commit b9934695fb
7 changed files with 167 additions and 179 deletions

View File

@ -20,6 +20,8 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring
METADATA = dict(
__name__="wolfssl",
__version__="0.1.0",

View File

@ -55,6 +55,7 @@ _VERIFY_MODE_LIST = [CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED]
_SSL_SUCCESS = 1
_SSL_FILETYPE_PEM = 1
_SSL_ERROR_WANT_READ = 2
class SSLContext(object):
"""
@ -112,8 +113,7 @@ class SSLContext(object):
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None):
suppress_ragged_eofs=True):
"""
Wrap an existing Python socket sock and return an SSLSocket object.
sock must be a SOCK_STREAM socket; other socket types are unsupported.
@ -126,7 +126,6 @@ class SSLContext(object):
return SSLSocket(sock=sock, server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
server_hostname=server_hostname,
_context=self)
@ -209,7 +208,7 @@ class SSLContext(object):
class SSLSocket(socket):
"""
This class implements a subtype of socket.socket that wraps the
underlying OS socket in an SSL context when necessary, and provides
underlying OS socket in an SSL/TLS connection, providing secure
read and write methods over that channel.
"""
@ -278,7 +277,7 @@ class SSLSocket(socket):
socket.__init__(self, family=family, sock_type=sock_type,
proto=proto)
# See if we are connected
# see if we are connected
try:
self.getpeername()
except OSError as exception:
@ -289,28 +288,38 @@ class SSLSocket(socket):
connected = True
self._closed = False
self.native_object = _ffi.NULL
self._connected = connected
# create the SSL object
self.native_object = _lib.wolfSSL_new(self.context.native_object)
if self.native_object == _ffi.NULL:
raise MemoryError("Unnable to allocate ssl object")
ret = _lib.wolfSSL_set_fd(self.native_object, self.fileno())
if ret != _SSL_SUCCESS:
self._release_native_object()
raise ValueError("Unnable to set fd to ssl object")
if connected:
# create the SSL object
try:
self.native_object = \
_lib.wolfSSL_new(self.context.native_object)
if self.native_object == _ffi.NULL:
raise MemoryError("Unnable to allocate ssl object")
ret = _lib.wolfSSL_set_fd(self.native_object, self.fileno)
if ret != _SSL_SUCCESS:
raise ValueError("Unnable to set fd to ssl object")
if do_handshake_on_connect:
self.do_handshake()
except (OSError, ValueError):
self.close()
except:
self._release_native_object()
self._socket.close()
raise
def __del__(self):
self._release_native_object()
def _release_native_object(self):
if self.native_object != _ffi.NULL:
_lib.wolfSSL_CTX_free(self.native_object)
self.native_object = _ffi.NULL
@property
def context(self):
"""
@ -324,6 +333,10 @@ class SSLSocket(socket):
self.__class__.__name__)
def _check_closed(self, call=None):
if self.native_object == _ffi.NULL:
raise ValueError("%s on closed or unwrapped secure channel" % call)
def _check_connected(self):
if not self._connected:
# getpeername() will raise ENOTCONN if the socket is really
@ -335,12 +348,11 @@ class SSLSocket(socket):
def write(self, data):
"""
Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted.
Write DATA to the underlying secure channel.
Returns number of bytes of DATA actually transmitted.
"""
if self.native_object == _ffi.NULL:
raise ValueError("Write on closed or unwrapped SSL socket")
self._check_closed("write")
self._check_connected()
data = t2b(data)
@ -348,134 +360,93 @@ class SSLSocket(socket):
def send(self, data, flags=0):
if self.native_object != _ffi.NULL:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
return self.write(data)
else:
return socket.send(self, data, flags)
if flags != 0:
raise NotImplementedError("non-zero flags not allowed in calls to "
"send() on %s" % self.__class__)
def sendto(self, data, flags_or_addr, addr=None):
if self.native_object != _ffi.NULL:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
elif addr is None:
return socket.sendto(self, data, flags_or_addr)
else:
return socket.sendto(self, data, flags_or_addr, addr)
def sendmsg(self, *args, **kwargs):
# Ensure programs don't send data unencrypted if they try to
# use this method.
raise NotImplementedError("sendmsg not allowed on instances of %s" %
self.__class__)
return self.write(data)
def sendall(self, data, flags=0):
if self.native_object != _ffi.NULL:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
if flags != 0:
raise NotImplementedError("non-zero flags not allowed in calls to "
"sendall() on %s" % self.__class__)
amount = len(data)
count = 0
while count < amount:
sent = self.send(data[count:])
count += sent
return amount
else:
return socket.sendall(self, data, flags)
length = len(data)
sent = 0
while sent < length:
sent += self.write(data[sent:])
return sent
def sendto(self, data, flags_or_addr, addr=None):
# Ensure programs don't send unencrypted data trying to use this method
raise NotImplementedError("sendto not allowed on instances "
"of %s" % self.__class__)
def sendmsg(self, *args, **kwargs):
# Ensure programs don't send unencrypted data trying to use this method
raise NotImplementedError("sendmsg not allowed on instances "
"of %s" % self.__class__)
def sendfile(self, file, offset=0, count=None):
"""
Send a file, possibly by using os.sendfile() if this is a
clear-text socket. Return the total number of bytes sent.
"""
# Ensure programs don't send unencrypted files if they try to
# use this method.
raise NotImplementedError("sendfile not allowed on instances of %s" %
self.__class__)
# Ensure programs don't send unencrypted files trying to use this method
raise NotImplementedError("sendfile not allowed on instances "
"of %s" % self.__class__)
def read(self, length=1024, buffer=None):
"""
Read up to LEN bytes and return them.
Read up to LENGTH bytes and return them.
Return zero-length string on EOF.
"""
self._check_closed("read")
self._check_connected()
if buffer is not None:
raise ValueError("buffer not allowed in calls to "
"read() on %s" % self.__class__)
if self.native_object == _ffi.NULL:
raise ValueError("Read on closed or unwrapped SSL socket")
data = t2b("\0" * length)
length = _lib.WolfSSL_read(self.native_object, data, length)
if buffer is not None:
buffer.write(data, length)
return length
else:
raise MemoryError("")
return self._sslobj.read(len, buffer)
except SSLError as exception:
if exception.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
return 0
else:
return b''
if length < 0:
err = _lib.wolfSSL_get_error(self.native_object, 0)
if err == _SSL_ERROR_WANT_READ:
raise SSLWantReadError()
else:
raise
raise SSLError("wolfSSL_read error (%d)" % err)
return data[:length] if length > 0 else b''
def recv(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
self.__class__)
return self.read(buflen)
else:
return socket.recv(self, buflen, flags)
def recv(self, length=1024, flags=0):
if flags != 0:
raise NotImplementedError("non-zero flags not allowed in calls to "
"recv() on %s" % self.__class__)
return self.read(self, length)
def recv_into(self, buffer, nbytes=None, flags=0):
self._checkClosed()
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s"
% self.__class__)
return self.read(nbytes, buffer)
else:
return socket.recv_into(self, buffer, nbytes, flags)
raise NotImplementedError("recv_into not allowed on instances "
"of %s" % self.__class__)
def recvfrom(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom(self, buflen, flags)
def recvfrom(self, length=1024, flags=0):
# Ensure programs don't receive encrypted data trying to use this method
raise NotImplementedError("recvfrom not allowed on instances "
"of %s" % self.__class__)
def recvfrom_into(self, buffer, nbytes=None, flags=0):
self._checkClosed()
if self._sslobj:
raise ValueError("recvfrom_into not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom_into(self, buffer, nbytes, flags)
# Ensure programs don't receive encrypted data trying to use this method
raise NotImplementedError("recvfrom_into not allowed on instances "
"of %s" % self.__class__)
def recvmsg(self, *args, **kwargs):
@ -489,84 +460,88 @@ class SSLSocket(socket):
def shutdown(self, how):
self._checkClosed()
self._sslobj = None
if self.native_object != _ffi.NULL:
_lib.wolfSSL_shutdown(self.native_object)
self._release_native_object()
socket.shutdown(self, how)
def unwrap(self):
if self._sslobj:
s = self._sslobj.unwrap()
self._sslobj = None
return s
else:
raise ValueError("No SSL wrapper around " + str(self))
"""
Unwraps the underlying OS socket from the SSL/TLS connection.
Returns the wrapped OS socket.
"""
if self.native_object != _ffi.NULL:
_lib.wolfSSL_set_fd(self.native_object, -1)
sock = socket(family=self.family,
sock_type=self.type,
proto=self.proto,
fileno=self.fileno())
sock.settimeout(self.gettimeout())
self.detach()
return sock
def _real_close(self):
self._sslobj = None
socket._real_close(self)
def do_handshake(self, block=False):
"""Perform a TLS/SSL handshake."""
"""
Perform a TLS/SSL handshake.
"""
self._check_closed("do_handshake")
self._check_connected()
timeout = self.gettimeout()
try:
if timeout == 0.0 and block:
self.settimeout(None)
self._sslobj.do_handshake()
finally:
self.settimeout(timeout)
ret = _lib.wolfSSL_negotiate(self.native_object)
if ret != _SSL_SUCCESS:
raise SSLError("do_handshake failed with error %d" % ret)
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
self._sslobj = SSLObject(sslobj, owner=self)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
else:
rc = None
socket.connect(self, addr)
if not rc:
self._connected = True
if self.do_handshake_on_connect:
self.do_handshake()
return rc
except (OSError, ValueError):
self._sslobj = None
raise
if connect_ex:
err = self._socket.connect_ex(addr)
else:
err = 0
self._socket.connect(addr)
if err == 0:
self._connected = True
if self.do_handshake_on_connect:
self.do_handshake()
return err
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
"""
Connects to remote ADDR, and then wraps the connection in a secure
channel.
"""
self._real_connect(addr, False)
def connect_ex(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
"""
Connects to remote ADDR, and then wraps the connection in a secure
channel.
"""
return self._real_connect(addr, True)
def accept(self):
"""Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
newsock = self.context.wrap_socket(
newsock,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs,
server_side=True)
return newsock, addr
"""
Accepts a new connection from a remote client, and returns a tuple
containing that new connection wrapped with a server-side secure
channel, and the address of the remote client.
"""
pass
def wrap_socket(sock, keyfile=None, certfile=None, server_side=False,

View File

@ -20,6 +20,8 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring
from socket import error as socket_error

View File

@ -19,9 +19,12 @@
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring
try:
from wolfssl._ffi import ffi as _ffi
from wolfssl._ffi import lib as _lib
from wolfssl._ffi import ffi as _ffi
from wolfssl._ffi import lib as _lib
except ImportError:
pass

View File

@ -19,9 +19,12 @@
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring, invalid-name
try:
from wolfssl._ffi import ffi as _ffi
from wolfssl._ffi import lib as _lib
from wolfssl._ffi import ffi as _ffi
from wolfssl._ffi import lib as _lib
except ImportError:
pass
@ -66,8 +69,8 @@ class WolfSSLMethod(object):
_lib.wolfTLSv1_2_client_method()
elif protocol in [PROTOCOL_SSLv23, PROTOCOL_TLS]:
self.native_object = \
_lib.wolfSSLv23_server_method() if server_side else \
self.native_object = \
_lib.wolfSSLv23_server_method() if server_side else \
_lib.wolfSSLv23_client_method()
if self.native_object == _ffi.NULL:

View File

@ -20,6 +20,8 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=missing-docstring, invalid-name
from cffi import FFI
ffi = FFI()
@ -63,6 +65,7 @@ ffi.cdef(
void wolfSSL_free(void*);
int wolfSSL_set_fd(void*, int);
int wolfSSL_get_error(void*, int);
int wolfSSL_negotiate(void*);
int wolfSSL_write(void*, const void*, int);
int wolfSSL_read(void*, void*, int);

View File

@ -20,7 +20,7 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
# pylint: disable=unused-import, undefined-variable
# pylint: disable=missing-docstring, unused-import, undefined-variable
import sys
from binascii import hexlify as b2h, unhexlify as h2b