THRIFT-2231 Support tornado-4.x (Python)

Client: Python
Patch: Roey Berman
Signed-off-by: Roger Meier <roger@apache.org>
This commit is contained in:
Roger Meier 2014-08-07 17:03:47 +02:00
parent f5136abeba
commit d52edba00f
3 changed files with 275 additions and 310 deletions

View File

@ -222,17 +222,6 @@ class t_py_generator : public t_generator {
void generate_python_docstring (std::ofstream& out, void generate_python_docstring (std::ofstream& out,
t_doc* tdoc); t_doc* tdoc);
/**
* a type for specifying to function_signature what type of Tornado callback
* parameter to add
*/
enum tornado_callback_t {
NONE = 0,
MANDATORY_FOR_ONEWAY_ELSE_NONE = 1,
OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY = 2,
};
/** /**
* Helper rendering functions * Helper rendering functions
*/ */
@ -245,8 +234,7 @@ class t_py_generator : public t_generator {
std::string render_field_default_value(t_field* tfield); std::string render_field_default_value(t_field* tfield);
std::string type_name(t_type* ttype); std::string type_name(t_type* ttype);
std::string function_signature(t_function* tfunction, std::string function_signature(t_function* tfunction,
bool interface=false, bool interface=false);
tornado_callback_t callback=NONE);
std::string argument_list(t_struct* tstruct, std::string argument_list(t_struct* tstruct,
std::vector<std::string> *pre=NULL, std::vector<std::string> *pre=NULL,
std::vector<std::string> *post=NULL); std::vector<std::string> *post=NULL);
@ -1067,7 +1055,8 @@ void t_py_generator::generate_service(t_service* tservice) {
"from thrift.transport import TTwisted" << endl; "from thrift.transport import TTwisted" << endl;
} else if (gen_tornado_) { } else if (gen_tornado_) {
f_service_ << "from tornado import gen" << endl; f_service_ << "from tornado import gen" << endl;
f_service_ << "from tornado import stack_context" << endl; f_service_ << "from tornado import concurrent" << endl;
f_service_ << "from thrift.transport import TTransport" << endl;
} }
f_service_ << endl; f_service_ << endl;
@ -1156,7 +1145,7 @@ void t_py_generator::generate_service_interface(t_service* tservice) {
vector<t_function*>::iterator f_iter; vector<t_function*>::iterator f_iter;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
f_service_ << f_service_ <<
indent() << "def " << function_signature(*f_iter, true, OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) << ":" << endl; indent() << "def " << function_signature(*f_iter, true) << ":" << endl;
indent_up(); indent_up();
generate_python_docstring(f_service_, (*f_iter)); generate_python_docstring(f_service_, (*f_iter));
f_service_ << f_service_ <<
@ -1229,6 +1218,7 @@ void t_py_generator::generate_service_client(t_service* tservice) {
indent() << " else iprot_factory)" << endl << indent() << " else iprot_factory)" << endl <<
indent() << " self._seqid = 0" << endl << indent() << " self._seqid = 0" << endl <<
indent() << " self._reqs = {}" << endl << indent() << " self._reqs = {}" << endl <<
indent() << " self._transport.io_loop.spawn_callback(self._start_receiving)" << endl <<
endl; endl;
} else { } else {
f_service_ << f_service_ <<
@ -1257,18 +1247,29 @@ void t_py_generator::generate_service_client(t_service* tservice) {
if (gen_tornado_ && extends.empty()) { if (gen_tornado_ && extends.empty()) {
f_service_ << f_service_ <<
indent() << "@gen.engine" << endl << indent() << "@gen.engine" << endl <<
indent() << "def recv_dispatch(self):" << endl << indent() << "def _start_receiving(self):" << endl <<
indent() << " \"\"\"read a response from the wire. schedule exactly one per send that" << endl << indent() << " while True:" << endl <<
indent() << " expects a response, but it doesn't matter which callee gets which" << endl << indent() << " try:" << endl <<
indent() << " response; they're dispatched here properly\"\"\"" << endl << indent() << " frame = yield self._transport.readFrame()" << endl <<
endl << indent() << " except TTransport.TTransportException as e:" << endl <<
indent() << " # wait for a frame header" << endl << indent() << " for future in self._reqs.itervalues():" << endl <<
indent() << " frame = yield gen.Task(self._transport.readFrame)" << endl << indent() << " future.set_exception(e)" << endl <<
indent() << " tr = TTransport.TMemoryBuffer(frame)" << endl << indent() << " self._reqs = {}" << endl <<
indent() << " iprot = self._iprot_factory.getProtocol(tr)" << endl << indent() << " return" << endl <<
indent() << " (fname, mtype, rseqid) = iprot.readMessageBegin()" << endl << indent() << " tr = TTransport.TMemoryBuffer(frame)" << endl <<
indent() << " method = getattr(self, 'recv_' + fname)" << endl << indent() << " iprot = self._iprot_factory.getProtocol(tr)" << endl <<
indent() << " method(iprot, mtype, rseqid)" << endl << indent() << " (fname, mtype, rseqid) = iprot.readMessageBegin()" << endl <<
indent() << " method = getattr(self, 'recv_' + fname)" << endl <<
indent() << " future = self._reqs.pop(rseqid, None)" << endl <<
indent() << " if not future:" << endl <<
indent() << " # future has already been discarded" << endl <<
indent() << " continue" << endl <<
indent() << " try:" << endl <<
indent() << " result = method(iprot, mtype, rseqid)" << endl <<
indent() << " except Exception as e:" << endl <<
indent() << " future.set_exception(e)" << endl <<
indent() << " else:" << endl <<
indent() << " future.set_result(result)" << endl <<
endl; endl;
} }
@ -1283,7 +1284,7 @@ void t_py_generator::generate_service_client(t_service* tservice) {
// Open function // Open function
indent(f_service_) << indent(f_service_) <<
"def " << function_signature(*f_iter, false, OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) << ":" << endl; "def " << function_signature(*f_iter, false) << ":" << endl;
indent_up(); indent_up();
generate_python_docstring(f_service_, (*f_iter)); generate_python_docstring(f_service_, (*f_iter));
if (gen_twisted_) { if (gen_twisted_) {
@ -1296,7 +1297,7 @@ void t_py_generator::generate_service_client(t_service* tservice) {
indent(f_service_) << "self._seqid += 1" << endl; indent(f_service_) << "self._seqid += 1" << endl;
if (!(*f_iter)->is_oneway()) { if (!(*f_iter)->is_oneway()) {
indent(f_service_) << indent(f_service_) <<
"self._reqs[self._seqid] = callback" << endl; "future = self._reqs[self._seqid] = concurrent.Future()" << endl;
} }
} }
@ -1313,15 +1314,6 @@ void t_py_generator::generate_service_client(t_service* tservice) {
f_service_ << (*fld_iter)->get_name(); f_service_ << (*fld_iter)->get_name();
} }
if (gen_tornado_ && (*f_iter)->is_oneway()) {
if (first) {
first = false;
} else {
f_service_ << ", ";
}
f_service_ << "callback";
}
f_service_ << ")" << endl; f_service_ << ")" << endl;
if (!(*f_iter)->is_oneway()) { if (!(*f_iter)->is_oneway()) {
@ -1329,7 +1321,7 @@ void t_py_generator::generate_service_client(t_service* tservice) {
if (gen_twisted_) { if (gen_twisted_) {
f_service_ << "return d" << endl; f_service_ << "return d" << endl;
} else if (gen_tornado_) { } else if (gen_tornado_) {
f_service_ << "self.recv_dispatch()" << endl; f_service_ << "return future" << endl;
} else { } else {
if (!(*f_iter)->get_returntype()->is_void()) { if (!(*f_iter)->get_returntype()->is_void()) {
f_service_ << "return "; f_service_ << "return ";
@ -1347,7 +1339,7 @@ void t_py_generator::generate_service_client(t_service* tservice) {
f_service_ << endl; f_service_ << endl;
indent(f_service_) << indent(f_service_) <<
"def send_" << function_signature(*f_iter, false, MANDATORY_FOR_ONEWAY_ELSE_NONE) << ":" << endl; "def send_" << function_signature(*f_iter, false) << ":" << endl;
indent_up(); indent_up();
@ -1374,24 +1366,11 @@ void t_py_generator::generate_service_client(t_service* tservice) {
} }
// Write to the stream // Write to the stream
if (gen_twisted_) { if (gen_twisted_ || gen_tornado_) {
f_service_ << f_service_ <<
indent() << "args.write(oprot)" << endl << indent() << "args.write(oprot)" << endl <<
indent() << "oprot.writeMessageEnd()" << endl << indent() << "oprot.writeMessageEnd()" << endl <<
indent() << "oprot.trans.flush()" << endl; indent() << "oprot.trans.flush()" << endl;
} else if (gen_tornado_) {
f_service_ <<
indent() << "args.write(oprot)" << endl <<
indent() << "oprot.writeMessageEnd()" << endl;
if ((*f_iter)->is_oneway()) {
// send_* carry the callback so you can block on the write's flush
// (rather than on receipt of the response)
f_service_ <<
indent() << "oprot.trans.flush(callback=callback)" << endl;
} else {
f_service_ <<
indent() << "oprot.trans.flush()" << endl;
}
} else { } else {
f_service_ << f_service_ <<
indent() << "args.write(self._oprot)" << endl << indent() << "args.write(self._oprot)" << endl <<
@ -1426,11 +1405,10 @@ void t_py_generator::generate_service_client(t_service* tservice) {
f_service_ << f_service_ <<
indent() << "d = self._reqs.pop(rseqid)" << endl; indent() << "d = self._reqs.pop(rseqid)" << endl;
} else if (gen_tornado_) { } else if (gen_tornado_) {
f_service_ <<
indent() << "callback = self._reqs.pop(rseqid)" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << "(fname, mtype, rseqid) = self._iprot.readMessageBegin()" << endl; indent() << "iprot = self._iprot" << endl <<
indent() << "(fname, mtype, rseqid) = iprot.readMessageBegin()" << endl;
} }
f_service_ << f_service_ <<
@ -1445,23 +1423,14 @@ void t_py_generator::generate_service_client(t_service* tservice) {
indent() << "result = " << resultname << "()" << endl << indent() << "result = " << resultname << "()" << endl <<
indent() << "result.read(iprot)" << endl << indent() << "result.read(iprot)" << endl <<
indent() << "iprot.readMessageEnd()" << endl; indent() << "iprot.readMessageEnd()" << endl;
} else if (gen_tornado_) { } else {
f_service_ << f_service_ <<
indent() << " x.read(iprot)" << endl << indent() << " x.read(iprot)" << endl <<
indent() << " iprot.readMessageEnd()" << endl << indent() << " iprot.readMessageEnd()" << endl <<
indent() << " callback(x)" << endl << indent() << " raise x" << endl <<
indent() << " return" << endl <<
indent() << "result = " << resultname << "()" << endl << indent() << "result = " << resultname << "()" << endl <<
indent() << "result.read(iprot)" << endl << indent() << "result.read(iprot)" << endl <<
indent() << "iprot.readMessageEnd()" << endl; indent() << "iprot.readMessageEnd()" << endl;
} else {
f_service_ <<
indent() << " x.read(self._iprot)" << endl <<
indent() << " self._iprot.readMessageEnd()" << endl <<
indent() << " raise x" << endl <<
indent() << "result = " << resultname << "()" << endl <<
indent() << "result.read(self._iprot)" << endl <<
indent() << "self._iprot.readMessageEnd()" << endl;
} }
// Careful, only return _result if not a void function // Careful, only return _result if not a void function
@ -1471,10 +1440,6 @@ void t_py_generator::generate_service_client(t_service* tservice) {
if (gen_twisted_) { if (gen_twisted_) {
f_service_ << f_service_ <<
indent() << " return d.callback(result.success)" << endl; indent() << " return d.callback(result.success)" << endl;
} else if (gen_tornado_) {
f_service_ <<
indent() << " callback(result.success)" << endl <<
indent() << " return" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << " return result.success" << endl; indent() << " return result.success" << endl;
@ -1490,11 +1455,6 @@ void t_py_generator::generate_service_client(t_service* tservice) {
if (gen_twisted_) { if (gen_twisted_) {
f_service_ << f_service_ <<
indent() << " return d.errback(result." << (*x_iter)->get_name() << ")" << endl; indent() << " return d.errback(result." << (*x_iter)->get_name() << ")" << endl;
} else if (gen_tornado_) {
f_service_ <<
indent() << " callback(result." << (*x_iter)->get_name() << ")" << endl <<
indent() << " return" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << " raise result." << (*x_iter)->get_name() << "" << endl; indent() << " raise result." << (*x_iter)->get_name() << "" << endl;
@ -1506,10 +1466,6 @@ void t_py_generator::generate_service_client(t_service* tservice) {
if (gen_twisted_) { if (gen_twisted_) {
f_service_ << f_service_ <<
indent() << "return d.callback(None)" << endl; indent() << "return d.callback(None)" << endl;
} else if (gen_tornado_) {
f_service_ <<
indent() << "callback(None)" << endl <<
indent() << "return" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << "return" << endl; indent() << "return" << endl;
@ -1518,10 +1474,6 @@ void t_py_generator::generate_service_client(t_service* tservice) {
if (gen_twisted_) { if (gen_twisted_) {
f_service_ << f_service_ <<
indent() << "return d.errback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl; indent() << "return d.errback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl;
} else if (gen_tornado_) {
f_service_ <<
indent() << "callback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl <<
indent() << "return" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl;
@ -1785,22 +1737,9 @@ void t_py_generator::generate_service_server(t_service* tservice) {
f_service_ << endl; f_service_ << endl;
// Generate the server implementation // Generate the server implementation
if (gen_tornado_) { f_service_ <<
f_service_ << indent() << "def process(self, iprot, oprot):" << endl;
indent() << "@gen.engine" << endl << indent_up();
indent() << "def process(self, transport, iprot_factory, oprot, callback):" << endl;
indent_up();
f_service_ <<
indent() << "# wait for a frame header" << endl <<
indent() << "frame = yield gen.Task(transport.readFrame)" << endl <<
indent() << "tr = TTransport.TMemoryBuffer(frame)" << endl <<
indent() << "iprot = iprot_factory.getProtocol(tr)" << endl <<
endl;
} else {
f_service_ <<
indent() << "def process(self, iprot, oprot):" << endl;
indent_up();
}
f_service_ << f_service_ <<
indent() << "(name, type, seqid) = iprot.readMessageBegin()" << endl; indent() << "(name, type, seqid) = iprot.readMessageBegin()" << endl;
@ -1821,8 +1760,6 @@ void t_py_generator::generate_service_server(t_service* tservice) {
if (gen_twisted_) { if (gen_twisted_) {
f_service_ << f_service_ <<
indent() << " return defer.succeed(None)" << endl; indent() << " return defer.succeed(None)" << endl;
} else if (gen_tornado_) {
// nothing
} else { } else {
f_service_ << f_service_ <<
indent() << " return" << endl; indent() << " return" << endl;
@ -1831,13 +1768,9 @@ void t_py_generator::generate_service_server(t_service* tservice) {
f_service_ << f_service_ <<
indent() << "else:" << endl; indent() << "else:" << endl;
if (gen_twisted_) { if (gen_twisted_ || gen_tornado_) {
f_service_ << f_service_ <<
indent() << " return self._processMap[name](self, seqid, iprot, oprot)" << endl; indent() << " return self._processMap[name](self, seqid, iprot, oprot)" << endl;
} else if (gen_tornado_) {
f_service_ <<
indent() << " yield gen.Task(self._processMap[name], self, seqid, iprot, oprot)" << endl <<
indent() << "callback()" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << " self._processMap[name](self, seqid, iprot, oprot)" << endl; indent() << " self._processMap[name](self, seqid, iprot, oprot)" << endl;
@ -1870,9 +1803,9 @@ void t_py_generator::generate_process_function(t_service* tservice,
// Open function // Open function
if (gen_tornado_) { if (gen_tornado_) {
f_service_ << f_service_ <<
indent() << "@gen.engine" << endl << indent() << "@gen.coroutine" << endl <<
indent() << "def process_" << tfunction->get_name() << indent() << "def process_" << tfunction->get_name() <<
"(self, seqid, iprot, oprot, callback):" << endl; "(self, seqid, iprot, oprot):" << endl;
} else { } else {
f_service_ << f_service_ <<
indent() << "def process_" << tfunction->get_name() << indent() << "def process_" << tfunction->get_name() <<
@ -1996,6 +1929,7 @@ void t_py_generator::generate_process_function(t_service* tservice,
} }
} else if (gen_tornado_) { } else if (gen_tornado_) {
/*
if (!tfunction->is_oneway() && xceptions.size() > 0) { if (!tfunction->is_oneway() && xceptions.size() > 0) {
f_service_ << f_service_ <<
endl << endl <<
@ -2014,21 +1948,27 @@ void t_py_generator::generate_process_function(t_service* tservice,
f_service_ << f_service_ <<
endl << endl <<
indent() << "with stack_context.ExceptionStackContext(handle_exception):" << endl; indent() << "try:" << endl;
indent_up(); indent_up();
} }
*/
// Generate the function call // Generate the function call
t_struct* arg_struct = tfunction->get_arglist(); t_struct* arg_struct = tfunction->get_arglist();
const std::vector<t_field*>& fields = arg_struct->get_members(); const std::vector<t_field*>& fields = arg_struct->get_members();
vector<t_field*>::const_iterator f_iter; vector<t_field*>::const_iterator f_iter;
if (xceptions.size() > 0) {
f_service_ <<
indent() << "try:" << endl;
indent_up();
}
f_service_ << indent(); f_service_ << indent();
if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) {
f_service_ << "result.success = "; f_service_ << "result.success = ";
} }
f_service_ << f_service_ <<
"yield gen.Task(self._handler." << tfunction->get_name() << ", "; "yield gen.maybe_future(self._handler." << tfunction->get_name() << "(";
bool first = true; bool first = true;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if (first) { if (first) {
@ -2038,22 +1978,27 @@ void t_py_generator::generate_process_function(t_service* tservice,
} }
f_service_ << "args." << (*f_iter)->get_name(); f_service_ << "args." << (*f_iter)->get_name();
} }
f_service_ << ")" << endl; f_service_ << "))" << endl;
if (xceptions.size() > 0) {
f_service_ << endl;
}
if (!tfunction->is_oneway() && xceptions.size() > 0) { if (!tfunction->is_oneway() && xceptions.size() > 0) {
indent_down(); indent_down();
for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
f_service_ <<
indent() << "except " << type_name((*x_iter)->get_type()) << ", " << (*x_iter)->get_name() << ":" << endl;
if (!tfunction->is_oneway()) {
indent_up();
f_service_ <<
indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << endl;
indent_down();
} else {
f_service_ <<
indent() << "pass" << endl;
}
}
} }
// Shortcut out here for oneway functions // Shortcut out here for oneway functions
if (tfunction->is_oneway()) { if (tfunction->is_oneway()) {
f_service_ <<
indent() << "callback()" << endl;
indent_down();
f_service_ << endl;
return; return;
} }
@ -2061,8 +2006,7 @@ void t_py_generator::generate_process_function(t_service* tservice,
indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << "\", TMessageType.REPLY, seqid)" << endl << indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << "\", TMessageType.REPLY, seqid)" << endl <<
indent() << "result.write(oprot)" << endl << indent() << "result.write(oprot)" << endl <<
indent() << "oprot.writeMessageEnd()" << endl << indent() << "oprot.writeMessageEnd()" << endl <<
indent() << "oprot.trans.flush()" << endl << indent() << "oprot.trans.flush()" << endl;
indent() << "callback()" << endl;
// Close function // Close function
indent_down(); indent_down();
@ -2621,8 +2565,7 @@ string t_py_generator::render_field_default_value(t_field* tfield) {
* @return String of rendered function definition * @return String of rendered function definition
*/ */
string t_py_generator::function_signature(t_function* tfunction, string t_py_generator::function_signature(t_function* tfunction,
bool interface, bool interface) {
tornado_callback_t callback) {
vector<string> pre; vector<string> pre;
vector<string> post; vector<string> post;
string signature = tfunction->get_name() + "("; string signature = tfunction->get_name() + "(";
@ -2631,22 +2574,6 @@ string t_py_generator::function_signature(t_function* tfunction,
pre.push_back("self"); pre.push_back("self");
} }
if (gen_tornado_) {
if (callback == NONE) {
} else if (callback == MANDATORY_FOR_ONEWAY_ELSE_NONE) {
if (tfunction->is_oneway()) {
// Tornado send_* carry the callback so you can block on the write's flush
// (rather than on receipt of the response)
post.push_back("callback");
}
} else if (callback == OPTIONAL_FOR_ONEWAY_ELSE_MANDATORY) {
if (tfunction->is_oneway()) {
post.push_back("callback=None");
} else {
post.push_back("callback");
}
}
}
signature += argument_list(tfunction->get_arglist(), &pre, &post) + ")"; signature += argument_list(tfunction->get_arglist(), &pre, &post) + ")";
return signature; return signature;
} }

View File

@ -17,58 +17,91 @@
# under the License. # under the License.
# #
from cStringIO import StringIO from __future__ import absolute_import
import logging import logging
import socket import socket
import struct import struct
from thrift.transport import TTransport from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
from thrift.transport.TTransport import TTransportException
from tornado import gen from io import BytesIO
from tornado import iostream from collections import deque
from tornado import tcpserver from contextlib import contextmanager
from tornado import gen, iostream, ioloop, tcpserver, concurrent
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
class TTornadoStreamTransport(TTransport.TTransportBase): class _Lock(object):
def __init__(self):
self._waiters = deque()
def acquired(self):
return len(self._waiters) > 0
@gen.coroutine
def acquire(self):
blocker = self._waiters[-1] if self.acquired() else None
future = concurrent.Future()
self._waiters.append(future)
if blocker:
yield blocker
raise gen.Return(self._lock_context())
def release(self):
assert self.acquired(), 'Lock not aquired'
future = self._waiters.popleft()
future.set_result(None)
@contextmanager
def _lock_context(self):
try:
yield
finally:
self.release()
class TTornadoStreamTransport(TTransportBase):
"""a framed, buffered transport over a Tornado stream""" """a framed, buffered transport over a Tornado stream"""
def __init__(self, host, port, stream=None): def __init__(self, host, port, stream=None, io_loop=None):
self.host = host self.host = host
self.port = port self.port = port
self.is_queuing_reads = False self.io_loop = io_loop or ioloop.IOLoop.current()
self.read_queue = [] self.__wbuf = BytesIO()
self.__wbuf = StringIO() self._read_lock = _Lock()
# servers provide a ready-to-go stream # servers provide a ready-to-go stream
self.stream = stream self.stream = stream
if self.stream is not None:
self._set_close_callback()
# not the same number of parameters as TTransportBase.open def with_timeout(self, timeout, future):
def open(self, callback): return gen.with_timeout(timeout, future, self.io_loop)
@gen.coroutine
def open(self, timeout=None):
logging.debug('socket connecting') logging.debug('socket connecting')
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.stream = iostream.IOStream(sock) self.stream = iostream.IOStream(sock)
def on_close_in_connect(*_): try:
message = 'could not connect to {}:{}'.format(self.host, self.port) connect = self.stream.connect((self.host, self.port))
if timeout is not None:
yield self.with_timeout(timeout, connect)
else:
yield connect
except (socket.error, IOError, ioloop.TimeoutError) as e:
message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
raise TTransportException( raise TTransportException(
type=TTransportException.NOT_OPEN, type=TTransportException.NOT_OPEN,
message=message) message=message)
self.stream.set_close_callback(on_close_in_connect)
def finish(*_): raise gen.Return(self)
self._set_close_callback()
callback()
self.stream.connect((self.host, self.port), callback=finish) def set_close_callback(self, callback):
"""
def _set_close_callback(self): Should be called only after open() returns
def on_close(): """
raise TTransportException( self.stream.set_close_callback(callback)
type=TTransportException.END_OF_FILE,
message='socket closed')
self.stream.set_close_callback(self.close)
def close(self): def close(self):
# don't raise if we intend to close # don't raise if we intend to close
@ -78,51 +111,45 @@ class TTornadoStreamTransport(TTransport.TTransportBase):
def read(self, _): def read(self, _):
# The generated code for Tornado shouldn't do individual reads -- only # The generated code for Tornado shouldn't do individual reads -- only
# frames at a time # frames at a time
assert "you're doing it wrong" is True assert False, "you're doing it wrong"
@gen.engine @contextmanager
def readFrame(self, callback): def io_exception_context(self):
self.read_queue.append(callback) try:
logging.debug('read queue: %s', self.read_queue) yield
except (socket.error, IOError) as e:
raise TTransportException(
type=TTransportException.END_OF_FILE,
message=str(e))
except iostream.StreamBufferFullError as e:
raise TTransportException(
type=TTransportException.UNKNOWN,
message=str(e))
if self.is_queuing_reads: @gen.coroutine
# If a read is already in flight, then the while loop below should def readFrame(self):
# pull it from self.read_queue # IOStream processes reads one at a time
return with (yield self._read_lock.acquire()):
with self.io_exception_context():
self.is_queuing_reads = True frame_header = yield self.stream.read_bytes(4)
while self.read_queue: if len(frame_header) == 0:
next_callback = self.read_queue.pop() raise iostream.StreamClosedError('Read zero bytes from stream')
result = yield gen.Task(self._readFrameFromStream) frame_length, = struct.unpack('!i', frame_header)
next_callback(result) logging.debug('received frame header, frame length = %d', frame_length)
self.is_queuing_reads = False frame = yield self.stream.read_bytes(frame_length)
logging.debug('received frame payload: %r', frame)
@gen.engine raise gen.Return(frame)
def _readFrameFromStream(self, callback):
logging.debug('_readFrameFromStream')
frame_header = yield gen.Task(self.stream.read_bytes, 4)
frame_length, = struct.unpack('!i', frame_header)
logging.debug('received frame header, frame length = %i', frame_length)
frame = yield gen.Task(self.stream.read_bytes, frame_length)
logging.debug('received frame payload')
callback(frame)
def write(self, buf): def write(self, buf):
self.__wbuf.write(buf) self.__wbuf.write(buf)
def flush(self, callback=None): def flush(self):
wout = self.__wbuf.getvalue() frame = self.__wbuf.getvalue()
wsz = len(wout)
# reset wbuf before write/flush to preserve state on underlying failure # reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = StringIO() frame_length = struct.pack('!i', len(frame))
# N.B.: Doing this string concatenation is WAY cheaper than making self.__wbuf = BytesIO()
# two separate calls to the underlying socket object. Socket writes in with self.io_exception_context():
# Python turn out to be REALLY expensive, but it seems to do a pretty return self.stream.write(frame_length + frame)
# good job of managing string buffer operations without excessive copies
buf = struct.pack("!i", wsz) + wout
logging.debug('writing frame length = %i', wsz)
self.stream.write(buf, callback)
class TTornadoServer(tcpserver.TCPServer): class TTornadoServer(tcpserver.TCPServer):
@ -135,19 +162,21 @@ class TTornadoServer(tcpserver.TCPServer):
self._oprot_factory = (oprot_factory if oprot_factory is not None self._oprot_factory = (oprot_factory if oprot_factory is not None
else iprot_factory) else iprot_factory)
@gen.coroutine
def handle_stream(self, stream, address): def handle_stream(self, stream, address):
host, port = address
trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
io_loop=self.io_loop)
oprot = self._oprot_factory.getProtocol(trans)
try: try:
host, port = address while not trans.stream.closed():
trans = TTornadoStreamTransport(host=host, port=port, stream=stream) frame = yield trans.readFrame()
oprot = self._oprot_factory.getProtocol(trans) tr = TMemoryBuffer(frame)
iprot = self._iprot_factory.getProtocol(tr)
def next_pass(): yield self._processor.process(iprot, oprot)
if not trans.stream.closed():
self._processor.process(trans, self._iprot_factory, oprot,
callback=next_pass)
next_pass()
except Exception: except Exception:
logging.exception('thrift exception in handle_stream') logging.exception('thrift exception in handle_stream')
trans.close() trans.close()
logging.info('client disconnected %s:%d', host, port)

View File

@ -22,11 +22,13 @@
import datetime import datetime
import glob import glob
import sys import sys
import os
import time import time
import unittest import unittest
sys.path.insert(0, './gen-py.tornado') basepath = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) sys.path.insert(0, basepath+'/gen-py.tornado')
sys.path.insert(0, glob.glob(os.path.join(basepath, '../../lib/py/build/lib.*'))[0])
try: try:
__import__('tornado') __import__('tornado')
@ -34,11 +36,12 @@ except ImportError:
print "module `tornado` not found, skipping test" print "module `tornado` not found, skipping test"
sys.exit(0) sys.exit(0)
from tornado import gen, ioloop, stack_context from tornado import gen
from tornado.testing import AsyncTestCase, get_unused_port from tornado.testing import AsyncTestCase, get_unused_port, gen_test
from thrift import TTornado from thrift import TTornado
from thrift.protocol import TBinaryProtocol from thrift.protocol import TBinaryProtocol
from thrift.transport.TTransport import TTransportException
from ThriftTest import ThriftTest from ThriftTest import ThriftTest
from ThriftTest.ttypes import * from ThriftTest.ttypes import *
@ -48,31 +51,31 @@ class TestHandler(object):
def __init__(self, test_instance): def __init__(self, test_instance):
self.test_instance = test_instance self.test_instance = test_instance
def testVoid(self, callback): def testVoid(self):
callback() pass
def testString(self, s, callback): def testString(self, s):
callback(s) return s
def testByte(self, b, callback): def testByte(self, b):
callback(b) return b
def testI16(self, i16, callback): def testI16(self, i16):
callback(i16) return i16
def testI32(self, i32, callback): def testI32(self, i32):
callback(i32) return i32
def testI64(self, i64, callback): def testI64(self, i64):
callback(i64) return i64
def testDouble(self, dub, callback): def testDouble(self, dub):
callback(dub) return dub
def testStruct(self, thing, callback): def testStruct(self, thing):
callback(thing) return thing
def testException(self, s, callback): def testException(self, s):
if s == 'Xception': if s == 'Xception':
x = Xception() x = Xception()
x.errorCode = 1001 x.errorCode = 1001
@ -80,133 +83,139 @@ class TestHandler(object):
raise x raise x
elif s == 'throw_undeclared': elif s == 'throw_undeclared':
raise ValueError("foo") raise ValueError("foo")
callback()
def testOneway(self, seconds, callback=None): def testOneway(self, seconds):
start = time.time() start = time.time()
def fire_oneway(): def fire_oneway():
end = time.time() end = time.time()
self.test_instance.stop((start, end, seconds)) self.test_instance.stop((start, end, seconds))
ioloop.IOLoop.instance().add_timeout( self.test_instance.io_loop.add_timeout(
datetime.timedelta(seconds=seconds), datetime.timedelta(seconds=seconds),
fire_oneway) fire_oneway)
if callback: def testNest(self, thing):
callback() return thing
def testNest(self, thing, callback): @gen.coroutine
callback(thing) def testMap(self, thing):
yield gen.moment
raise gen.Return(thing)
def testMap(self, thing, callback): def testSet(self, thing):
callback(thing) return thing
def testSet(self, thing, callback): def testList(self, thing):
callback(thing) return thing
def testList(self, thing, callback): def testEnum(self, thing):
callback(thing) return thing
def testEnum(self, thing, callback): def testTypedef(self, thing):
callback(thing) return thing
def testTypedef(self, thing, callback):
callback(thing)
class ThriftTestCase(AsyncTestCase): class ThriftTestCase(AsyncTestCase):
def get_new_ioloop(self):
return ioloop.IOLoop.instance()
def setUp(self): def setUp(self):
super(ThriftTestCase, self).setUp()
self.port = get_unused_port() self.port = get_unused_port()
self.io_loop = self.get_new_ioloop()
# server # server
self.handler = TestHandler(self) self.handler = TestHandler(self)
self.processor = ThriftTest.Processor(self.handler) self.processor = ThriftTest.Processor(self.handler)
self.pfactory = TBinaryProtocol.TBinaryProtocolFactory() self.pfactory = TBinaryProtocol.TBinaryProtocolFactory()
self.server = TTornado.TTornadoServer(self.processor, self.pfactory) self.server = TTornado.TTornadoServer(self.processor, self.pfactory, io_loop=self.io_loop)
self.server.bind(self.port) self.server.bind(self.port)
self.server.start(1) self.server.start(1)
# client # client
transport = TTornado.TTornadoStreamTransport('localhost', self.port) transport = TTornado.TTornadoStreamTransport('localhost', self.port, io_loop=self.io_loop)
pfactory = TBinaryProtocol.TBinaryProtocolFactory() pfactory = TBinaryProtocol.TBinaryProtocolFactory()
self.io_loop.run_sync(transport.open)
self.client = ThriftTest.Client(transport, pfactory) self.client = ThriftTest.Client(transport, pfactory)
transport.open(callback=self.stop)
self.wait(timeout=1)
@gen_test
def test_void(self): def test_void(self):
self.client.testVoid(callback=self.stop) v = yield self.client.testVoid()
v = self.wait(timeout=1) self.assertEqual(v, None)
self.assertEquals(v, None)
@gen_test
def test_string(self): def test_string(self):
self.client.testString('Python', callback=self.stop) v = yield self.client.testString('Python')
v = self.wait(timeout=1) self.assertEqual(v, 'Python')
self.assertEquals(v, 'Python')
@gen_test
def test_byte(self): def test_byte(self):
self.client.testByte(63, callback=self.stop) v = yield self.client.testByte(63)
v = self.wait(timeout=1) self.assertEqual(v, 63)
self.assertEquals(v, 63)
@gen_test
def test_i32(self): def test_i32(self):
self.client.testI32(-1, callback=self.stop) v = yield self.client.testI32(-1)
v = self.wait(timeout=1) self.assertEqual(v, -1)
self.assertEquals(v, -1)
self.client.testI32(0, callback=self.stop) v = yield self.client.testI32(0)
v = self.wait(timeout=1) self.assertEqual(v, 0)
self.assertEquals(v, 0)
@gen_test
def test_i64(self): def test_i64(self):
self.client.testI64(-34359738368, callback=self.stop) v = yield self.client.testI64(-34359738368)
v = self.wait(timeout=1) self.assertEqual(v, -34359738368)
self.assertEquals(v, -34359738368)
@gen_test
def test_double(self): def test_double(self):
self.client.testDouble(-5.235098235, callback=self.stop) v = yield self.client.testDouble(-5.235098235)
v = self.wait(timeout=1) self.assertEqual(v, -5.235098235)
self.assertEquals(v, -5.235098235)
@gen_test
def test_struct(self): def test_struct(self):
x = Xtruct() x = Xtruct()
x.string_thing = "Zero" x.string_thing = "Zero"
x.byte_thing = 1 x.byte_thing = 1
x.i32_thing = -3 x.i32_thing = -3
x.i64_thing = -5 x.i64_thing = -5
self.client.testStruct(x, callback=self.stop) y = yield self.client.testStruct(x)
y = self.wait(timeout=1) self.assertEqual(y.string_thing, "Zero")
self.assertEquals(y.string_thing, "Zero") self.assertEqual(y.byte_thing, 1)
self.assertEquals(y.byte_thing, 1) self.assertEqual(y.i32_thing, -3)
self.assertEquals(y.i32_thing, -3) self.assertEqual(y.i64_thing, -5)
self.assertEquals(y.i64_thing, -5)
def test_exception(self):
self.client.testException('Safe', callback=self.stop)
v = self.wait(timeout=1)
self.client.testException('Xception', callback=self.stop)
ex = self.wait(timeout=1)
if type(ex) == Xception:
self.assertEquals(ex.errorCode, 1001)
self.assertEquals(ex.message, 'Xception')
else:
self.fail("should have gotten exception")
def test_oneway(self): def test_oneway(self):
def return_from_send(): self.client.testOneway(0.5)
self.stop('done with send')
self.client.testOneway(0.5, callback=return_from_send)
self.assertEquals(self.wait(timeout=1), 'done with send')
start, end, seconds = self.wait(timeout=1) start, end, seconds = self.wait(timeout=1)
self.assertAlmostEquals(seconds, (end - start), places=3) self.assertAlmostEquals(seconds, (end - start), places=3)
@gen_test
def test_map(self):
"""
TestHandler.testMap is a coroutine, this test checks if gen.Return() from a coroutine works.
"""
expected = {1: 1}
res = yield self.client.testMap(expected)
self.assertEqual(res, expected)
@gen_test
def test_exception(self):
yield self.client.testException('Safe')
try:
yield self.client.testException('Xception')
except Xception as ex:
self.assertEqual(ex.errorCode, 1001)
self.assertEqual(ex.message, 'Xception')
else:
self.fail("should have gotten exception")
try:
yield self.client.testException('throw_undeclared')
except TTransportException as ex:
pass
else:
self.fail("should have gotten exception")
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()