THRIFT-4904: Fix python unit test errors and exception escapes

Due to the way SSL layers on top of sockets, it was possible
to complete a connection and then have the server close it.
This would happen if the client is not checking certificates
but the server is.  The TSSLSocket unit test was enhanced to
do a read and a write as well as just connecting to ensure a
more complete test.

The TSocket read() and write() calls were leaking OSError,
socker.error, and ssl.Error exceptions.  These cases are now
wrapped into a TTransportException of the appropriate type,
and the original exception is added as an argument named inner.
This commit is contained in:
James E. King III 2019-07-02 14:21:05 -04:00
parent 93ff9b0053
commit 3131fe975c
4 changed files with 29 additions and 17 deletions

View File

@ -291,11 +291,11 @@ class TSSLSocket(TSocket.TSocket, TSSLBase):
plain_sock = socket.socket(family, socktype)
try:
return self._wrap_socket(plain_sock)
except Exception:
except Exception as ex:
plain_sock.close()
msg = 'failed to initialize SSL'
logger.exception(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex)
def open(self):
super(TSSLSocket, self).open()
@ -307,7 +307,7 @@ class TSSLSocket(TSocket.TSocket, TSSLBase):
except TTransportException:
raise
except Exception as ex:
raise TTransportException(TTransportException.UNKNOWN, str(ex))
raise TTransportException(message=str(ex), inner=ex)
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):

View File

@ -94,13 +94,13 @@ class TSocket(TSocketBase):
def open(self):
if self.handle:
raise TTransportException(TTransportException.ALREADY_OPEN)
raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
try:
addrs = self._resolveAddr()
except socket.gaierror:
except socket.gaierror as gai:
msg = 'failed to resolve sockaddr for ' + str(self._address)
logger.exception(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
for family, socktype, _, _, sockaddr in addrs:
handle = self._do_open(family, socktype)
@ -119,7 +119,7 @@ class TSocket(TSocketBase):
msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
addrs))
logger.error(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)
def read(self, sz):
try:
@ -134,8 +134,10 @@ class TSocket(TSocketBase):
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
buff = ''
elif e.args[0] == errno.ETIMEDOUT:
raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
else:
raise
raise TTransportException(message="unexpected exception", inner=e)
if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes')
@ -148,12 +150,15 @@ class TSocket(TSocketBase):
sent = 0
have = len(buff)
while sent < have:
plus = self.handle.send(buff)
if plus == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket sent 0 bytes')
sent += plus
buff = buff[plus:]
try:
plus = self.handle.send(buff)
if plus == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket sent 0 bytes')
sent += plus
buff = buff[plus:]
except socket.error as e:
raise TTransportException(message="unexpected exception", inner=e)
def flush(self):
pass

View File

@ -34,9 +34,10 @@ class TTransportException(TException):
SIZE_LIMIT = 6
INVALID_CLIENT_TYPE = 7
def __init__(self, type=UNKNOWN, message=None):
def __init__(self, type=UNKNOWN, message=None, inner=None):
TException.__init__(self, message)
self.type = type
self.inner = inner
class TTransportBase(object):

View File

@ -75,6 +75,9 @@ class ServerAcceptor(threading.Thread):
try:
self._client = self._server.accept()
if self._client:
self._client.read(5) # hello
self._client.write(b"there")
except Exception:
logging.exception('error on server side (%s):' % self.name)
if not self._expect_failure:
@ -141,7 +144,8 @@ class TSSLSocketTest(unittest.TestCase):
client.setTimeout(20)
with self._assert_raises(TTransportException):
client.open()
self.assertTrue(acc.client is None)
client.write(b"hello")
client.read(5) # b"there"
finally:
logging.disable(logging.NOTSET)
@ -153,8 +157,10 @@ class TSSLSocketTest(unittest.TestCase):
def _assert_connection_success(self, server, path=None, **client_args):
with self._connectable_client(server, path=path, **client_args) as (acc, client):
client.open()
try:
client.open()
client.write(b"hello")
self.assertEqual(client.read(5), b"there")
self.assertTrue(acc.client is not None)
finally:
client.close()