mirror of
https://github.com/valitydev/thrift.git
synced 2024-11-07 02:45:22 +00:00
THRIFT-67. python: Add TNonblockingServer
This TNonblockingServer is very similar to the C++ implementation. It assumes the framed transport, but it uses select instead of libevent. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@712306 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
762e9bd15f
commit
74421273ad
291
lib/py/src/server/TNonblockingServer.py
Normal file
291
lib/py/src/server/TNonblockingServer.py
Normal file
@ -0,0 +1,291 @@
|
||||
"""Implementation of non-blocking server.
|
||||
|
||||
The main idea of the server is reciving and sending requests
|
||||
only from main thread.
|
||||
|
||||
It also makes thread pool server in tasks terms, not connections.
|
||||
"""
|
||||
import threading
|
||||
import socket
|
||||
import Queue
|
||||
import select
|
||||
import struct
|
||||
import logging
|
||||
|
||||
from thrift.transport import TTransport
|
||||
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
|
||||
|
||||
__all__ = ['TNonblockingServer']
|
||||
|
||||
class Worker(threading.Thread):
|
||||
"""Worker is a small helper to process incoming connection."""
|
||||
def __init__(self, queue):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue
|
||||
|
||||
def run(self):
|
||||
"""Process queries from task queue, stop if processor is None."""
|
||||
while True:
|
||||
try:
|
||||
processor, iprot, oprot, otrans, callback = self.queue.get()
|
||||
if processor is None:
|
||||
break
|
||||
processor.process(iprot, oprot)
|
||||
callback(True, otrans.getvalue())
|
||||
except Exception:
|
||||
logging.exception("Exception while processing request")
|
||||
callback(False, '')
|
||||
|
||||
WAIT_LEN = 0
|
||||
WAIT_MESSAGE = 1
|
||||
WAIT_PROCESS = 2
|
||||
SEND_ANSWER = 3
|
||||
CLOSED = 4
|
||||
|
||||
def locked(func):
|
||||
"Decorator which locks self.lock."
|
||||
def nested(self, *args, **kwargs):
|
||||
self.lock.acquire()
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
self.lock.release()
|
||||
return nested
|
||||
|
||||
def socket_exception(func):
|
||||
"Decorator close object on socket.error."
|
||||
def read(self, *args, **kwargs):
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except socket.error:
|
||||
self.close()
|
||||
return read
|
||||
|
||||
class Connection:
|
||||
"""Basic class is represented connection.
|
||||
|
||||
It can be in state:
|
||||
WAIT_LEN --- connection is reading request len.
|
||||
WAIT_MESSAGE --- connection is reading request.
|
||||
WAIT_PROCESS --- connection has just read whole request and
|
||||
waits for call ready routine.
|
||||
SEND_ANSWER --- connection is sending answer string (including length
|
||||
of answer).
|
||||
CLOSED --- socket was closed and connection should be deleted.
|
||||
"""
|
||||
def __init__(self, new_socket, wake_up):
|
||||
self.socket = new_socket
|
||||
self.socket.setblocking(False)
|
||||
self.status = WAIT_LEN
|
||||
self.len = 0
|
||||
self.message = ''
|
||||
self.lock = threading.Lock()
|
||||
self.wake_up = wake_up
|
||||
|
||||
def _read_len(self):
|
||||
"""Reads length of request.
|
||||
|
||||
It's really paranoic routine and it may be replaced by
|
||||
self.socket.recv(4)."""
|
||||
read = self.socket.recv(4 - len(self.message))
|
||||
if len(read) == 0:
|
||||
# if we read 0 bytes and self.message is empty, it means client close
|
||||
# connection
|
||||
if len(self.message) != 0:
|
||||
logging.error("can't read frame size from socket")
|
||||
self.close()
|
||||
return
|
||||
self.message += read
|
||||
if len(self.message) == 4:
|
||||
self.len, = struct.unpack('!i', self.message)
|
||||
if self.len < 0:
|
||||
logging.error("negative frame size, it seems client"\
|
||||
" doesn't use FramedTransport")
|
||||
self.close()
|
||||
elif self.len == 0:
|
||||
logging.error("empty frame, it's really strange")
|
||||
self.close()
|
||||
else:
|
||||
self.message = ''
|
||||
self.status = WAIT_MESSAGE
|
||||
|
||||
@socket_exception
|
||||
def read(self):
|
||||
"""Reads data from stream and switch state."""
|
||||
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
|
||||
if self.status == WAIT_LEN:
|
||||
self._read_len()
|
||||
# go back to the main loop here for simplicity instead of
|
||||
# falling through, even though there is a good chance that
|
||||
# the message is already available
|
||||
elif self.status == WAIT_MESSAGE:
|
||||
read = self.socket.recv(self.len - len(self.message))
|
||||
if len(read) == 0:
|
||||
logging.error("can't read frame from socket (get %d of %d bytes)" %
|
||||
(len(self.message), self.len))
|
||||
self.close()
|
||||
return
|
||||
self.message += read
|
||||
if len(self.message) == self.len:
|
||||
self.status = WAIT_PROCESS
|
||||
|
||||
@socket_exception
|
||||
def write(self):
|
||||
"""Writes data from socket and switch state."""
|
||||
assert self.status == SEND_ANSWER
|
||||
sent = self.socket.send(self.message)
|
||||
if sent == len(self.message):
|
||||
self.status = WAIT_LEN
|
||||
self.message = ''
|
||||
self.len = 0
|
||||
else:
|
||||
self.message = self.message[sent:]
|
||||
|
||||
@locked
|
||||
def ready(self, all_ok, message):
|
||||
"""Callback function for switching state and waking up main thread.
|
||||
|
||||
This function is the only function witch can be called asynchronous.
|
||||
|
||||
The ready can switch Connection to three states:
|
||||
WAIT_LEN if request was async.
|
||||
SEND_ANSWER if request was processed in normal way.
|
||||
CLOSED if request throws unexpected exception.
|
||||
|
||||
The one wakes up main thread.
|
||||
"""
|
||||
assert self.status == WAIT_PROCESS
|
||||
if not all_ok:
|
||||
self.close()
|
||||
self.wake_up()
|
||||
return
|
||||
self.len = ''
|
||||
self.message = struct.pack('!i', len(message)) + message
|
||||
if len(message) == 0:
|
||||
# it was async request, do not write answer
|
||||
self.status = WAIT_LEN
|
||||
else:
|
||||
self.status = SEND_ANSWER
|
||||
self.wake_up()
|
||||
|
||||
@locked
|
||||
def is_writeable(self):
|
||||
"Returns True if connection should be added to write list of select."
|
||||
return self.status == SEND_ANSWER
|
||||
|
||||
# it's not necessary, but...
|
||||
@locked
|
||||
def is_readable(self):
|
||||
"Returns True if connection should be added to read list of select."
|
||||
return self.status in (WAIT_LEN, WAIT_MESSAGE)
|
||||
|
||||
@locked
|
||||
def is_closed(self):
|
||||
"Returns True if connection is closed."
|
||||
return self.status == CLOSED
|
||||
|
||||
def fileno(self):
|
||||
"Returns the file descriptor of the associated socket."
|
||||
return self.socket.fileno()
|
||||
|
||||
def close(self):
|
||||
"Closes connection"
|
||||
self.status = CLOSED
|
||||
self.socket.close()
|
||||
|
||||
class TNonblockingServer:
|
||||
"""Non-blocking server."""
|
||||
def __init__(self, processor, lsocket, inputProtocolFactory=None,
|
||||
outputProtocolFactory=None, threads=10):
|
||||
self.processor = processor
|
||||
self.socket = lsocket
|
||||
self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
|
||||
self.out_protocol = outputProtocolFactory or self.in_protocol
|
||||
self.threads = int(threads)
|
||||
self.clients = {}
|
||||
self.tasks = Queue.Queue()
|
||||
self._read, self._write = socket.socketpair()
|
||||
self.prepared = False
|
||||
|
||||
def setNumThreads(self, num):
|
||||
"""Set the number of worker threads that should be created."""
|
||||
# implement ThreadPool interface
|
||||
assert not self.prepared, "You can't change number of threads for working server"
|
||||
self.threads = num
|
||||
|
||||
def prepare(self):
|
||||
"""Prepares server for serve requests."""
|
||||
self.socket.listen()
|
||||
for _ in xrange(self.threads):
|
||||
thread = Worker(self.tasks)
|
||||
thread.setDaemon(True)
|
||||
thread.start()
|
||||
self.prepared = True
|
||||
|
||||
def wake_up(self):
|
||||
"""Wake up main thread.
|
||||
|
||||
The server usualy waits in select call in we should terminate one.
|
||||
The simplest way is using socketpair.
|
||||
|
||||
Select always wait to read from the first socket of socketpair.
|
||||
|
||||
In this case, we can just write anything to the second socket from
|
||||
socketpair."""
|
||||
self._write.send('1')
|
||||
|
||||
def _select(self):
|
||||
"""Does select on open connections."""
|
||||
readable = [self.socket.handle.fileno(), self._read.fileno()]
|
||||
writable = []
|
||||
for i, connection in self.clients.items():
|
||||
if connection.is_readable():
|
||||
readable.append(connection.fileno())
|
||||
if connection.is_writeable():
|
||||
writable.append(connection.fileno())
|
||||
if connection.is_closed():
|
||||
del self.clients[i]
|
||||
return select.select(readable, writable, readable)
|
||||
|
||||
def handle(self):
|
||||
"""Handle requests.
|
||||
|
||||
WARNING! You must call prepare BEFORE calling handle.
|
||||
"""
|
||||
assert self.prepared, "You have to call prepare before handle"
|
||||
rset, wset, xset = self._select()
|
||||
for readable in rset:
|
||||
if readable == self._read.fileno():
|
||||
# don't care i just need to clean readable flag
|
||||
self._read.recv(1024)
|
||||
elif readable == self.socket.handle.fileno():
|
||||
client = self.socket.accept().handle
|
||||
self.clients[client.fileno()] = Connection(client, self.wake_up)
|
||||
else:
|
||||
connection = self.clients[readable]
|
||||
connection.read()
|
||||
if connection.status == WAIT_PROCESS:
|
||||
itransport = TTransport.TMemoryBuffer(connection.message)
|
||||
otransport = TTransport.TMemoryBuffer()
|
||||
iprot = self.in_protocol.getProtocol(itransport)
|
||||
oprot = self.out_protocol.getProtocol(otransport)
|
||||
self.tasks.put([self.processor, iprot, oprot,
|
||||
otransport, connection.ready])
|
||||
for writeable in wset:
|
||||
self.clients[writeable].write()
|
||||
for oob in xset:
|
||||
self.clients[oob].close()
|
||||
del self.clients[oob]
|
||||
|
||||
def close(self):
|
||||
"""Closes the server."""
|
||||
for _ in xrange(self.threads):
|
||||
self.tasks.put([None, None, None, None, None])
|
||||
self.socket.close()
|
||||
self.prepared = False
|
||||
|
||||
def serve(self):
|
||||
"""Serve forever."""
|
||||
self.prepare()
|
||||
while True:
|
||||
self.handle()
|
@ -4,4 +4,4 @@
|
||||
# See accompanying file LICENSE or visit the Thrift site at:
|
||||
# http://developers.facebook.com/thrift/
|
||||
|
||||
__all__ = ['TServer']
|
||||
__all__ = ['TServer', 'TNonblockingServer']
|
||||
|
@ -9,12 +9,16 @@ import signal
|
||||
def relfile(fname):
|
||||
return os.path.join(os.path.dirname(__file__), fname)
|
||||
|
||||
FRAMED = ["TNonblockingServer"]
|
||||
|
||||
def runTest(server_class):
|
||||
print "Testing ", server_class
|
||||
serverproc = subprocess.Popen([sys.executable, relfile("TestServer.py"), server_class])
|
||||
try:
|
||||
|
||||
ret = subprocess.call([sys.executable, relfile("TestClient.py")])
|
||||
argv = [sys.executable, relfile("TestClient.py")]
|
||||
if server_class in FRAMED:
|
||||
argv.append('--framed')
|
||||
ret = subprocess.call(argv)
|
||||
if ret != 0:
|
||||
raise Exception("subprocess failed")
|
||||
finally:
|
||||
@ -25,4 +29,4 @@ def runTest(server_class):
|
||||
time.sleep(5)
|
||||
|
||||
map(runTest, ["TForkingServer", "TThreadPoolServer",
|
||||
"TThreadedServer", "TSimpleServer"])
|
||||
"TThreadedServer", "TSimpleServer", "TNonblockingServer"])
|
||||
|
@ -15,24 +15,29 @@ from optparse import OptionParser
|
||||
|
||||
|
||||
parser = OptionParser()
|
||||
parser.set_defaults(framed=False, verbose=1, host='localhost', port=9090)
|
||||
parser.add_option("--port", type="int", dest="port",
|
||||
help="connect to server at port")
|
||||
parser.add_option("--host", type="string", dest="host",
|
||||
help="connect to server")
|
||||
parser.add_option("--framed", action="store_true", dest="framed",
|
||||
help="use framed transport")
|
||||
parser.add_option('-v', '--verbose', action="store_const",
|
||||
dest="verbose", const=2,
|
||||
help="verbose output")
|
||||
parser.add_option('-q', '--quiet', action="store_const",
|
||||
dest="verbose", const=0,
|
||||
help="minimal output")
|
||||
|
||||
parser.add_option("--port", type="int", dest="port", default=9090)
|
||||
parser.add_option("--host", type="string", dest="host", default='localhost')
|
||||
parser.add_option("--framed-input", action="store_true", dest="framed_input")
|
||||
parser.add_option("--framed-output", action="store_false", dest="framed_output")
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
options, args = parser.parse_args()
|
||||
|
||||
class AbstractTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
global options
|
||||
|
||||
socket = TSocket.TSocket(options.host, options.port)
|
||||
|
||||
# Frame or buffer depending upon args
|
||||
if options.framed_input or options.framed_output:
|
||||
self.transport = TTransport.TFramedTransport(socket, options.framed_input, options.framed_output)
|
||||
if options.framed:
|
||||
self.transport = TTransport.TFramedTransport(socket)
|
||||
else:
|
||||
self.transport = TTransport.TBufferedTransport(socket)
|
||||
|
||||
@ -113,5 +118,13 @@ def suite():
|
||||
suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
|
||||
return suite
|
||||
|
||||
class OwnArgsTestProgram(unittest.TestProgram):
|
||||
def parseArgs(self, argv):
|
||||
if args:
|
||||
self.testNames = args
|
||||
else:
|
||||
self.testNames = (self.defaultTest,)
|
||||
self.createTests()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
|
||||
OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
|
||||
|
@ -9,7 +9,7 @@ from ThriftTest.ttypes import *
|
||||
from thrift.transport import TTransport
|
||||
from thrift.transport import TSocket
|
||||
from thrift.protocol import TBinaryProtocol
|
||||
from thrift.server import TServer
|
||||
from thrift.server import TServer, TNonblockingServer
|
||||
|
||||
class TestHandler:
|
||||
|
||||
@ -59,13 +59,33 @@ class TestHandler:
|
||||
time.sleep(seconds)
|
||||
print 'done sleeping'
|
||||
|
||||
def testNest(self, thing):
|
||||
return thing
|
||||
|
||||
def testMap(self, thing):
|
||||
return thing
|
||||
|
||||
def testSet(self, thing):
|
||||
return thing
|
||||
|
||||
def testList(self, thing):
|
||||
return thing
|
||||
|
||||
def testEnum(self, thing):
|
||||
return thing
|
||||
|
||||
def testTypedef(self, thing):
|
||||
return thing
|
||||
|
||||
handler = TestHandler()
|
||||
processor = ThriftTest.Processor(handler)
|
||||
transport = TSocket.TServerSocket(9090)
|
||||
tfactory = TTransport.TBufferedTransportFactory()
|
||||
pfactory = TBinaryProtocol.TBinaryProtocolFactory()
|
||||
|
||||
ServerClass = getattr(TServer, sys.argv[1])
|
||||
|
||||
server = ServerClass(processor, transport, tfactory, pfactory)
|
||||
if sys.argv[1] == "TNonblockingServer":
|
||||
server = TNonblockingServer.TNonblockingServer(processor, transport)
|
||||
else:
|
||||
ServerClass = getattr(TServer, sys.argv[1])
|
||||
server = ServerClass(processor, transport, tfactory, pfactory)
|
||||
server.serve()
|
||||
|
Loading…
Reference in New Issue
Block a user