Switch ReqServer over to IOLoop

This commit is contained in:
Thomas Jackson 2015-03-07 09:22:59 -08:00
parent b1bb0d62d2
commit 5bee5649db
4 changed files with 113 additions and 153 deletions

View File

@ -28,6 +28,9 @@ import salt.ext.six as six
from salt.ext.six.moves import range
# pylint: enable=import-error,no-name-in-module,redefined-builtin
import zmq.eventloop.ioloop
import tornado.gen
# Import salt libs
import salt.crypt
import salt.utils
@ -651,51 +654,10 @@ class MWorker(multiprocessing.Process):
'''
Bind to the local port
'''
self.req_channel.post_fork() # TODO: cleaner? Maybe lazily?
try:
while True:
try:
payload = self.req_channel.recv(None) # blocking get
# TODO: maybe change into a wrapper class?
# req_opts defines our response function
try:
ret, req_opts = self._handle_payload(payload)
except KeyError:
ret, req_opts = 'error', {}
log.debug('Exception when handling payload', exc_info=True)
# req_fun: default to send
req_fun = req_opts.get('fun', 'send')
if req_fun == 'send_clear':
self.req_channel.send_clear(ret)
elif req_fun == 'send':
self.req_channel.send(ret)
elif req_fun == 'send_private':
self.req_channel.send_private(ret, req_opts['key'], req_opts['tgt'])
else:
log.error('Unknown req_fun {0}'.format(req_fun))
# don't catch keyboard interrupts, just re-raise them
except KeyboardInterrupt:
raise
# catch all other exceptions, so we don't go defunct
except Exception as exc:
# since we are in an exceptional state, lets attempt to tell
# the minion we have a problem, otherwise the minion will get
# no response and be forced to wait for their max timeout
try:
socket.send('Unexpected Error in Mworker')
except: # pylint: disable=W0702
pass
# Properly handle EINTR from SIGUSR1
if isinstance(exc, zmq.ZMQError) and exc.errno == errno.EINTR:
continue
log.critical('Unexpected Error in MWorker',
exc_info=True)
# Changes here create a zeromq condition, check with thatch45 before
# making any zeromq changes
except KeyboardInterrupt:
del self.req_channel
# using ZMQIOLoop since we *might* need zmq in there
io_loop = zmq.eventloop.ioloop.ZMQIOLoop()
self.req_channel.post_fork(self._handle_payload, io_loop=io_loop) # TODO: cleaner? Maybe lazily?
io_loop.start()
def _handle_payload(self, payload):
'''

View File

@ -127,23 +127,11 @@ class AESReqServerMixin(object):
# we need to decrypt it
if payload['enc'] == 'aes':
try:
try:
payload['load'] = self.crypticle.loads(payload['load'])
except salt.crypt.AuthenticationError:
if not self._update_aes():
raise
payload['load'] = self.crypticle.loads(payload['load'])
except Exception:
# send something back to the client so the client so they know
# their load was malformed
self.send('bad load')
raise
# intercept the "_auth" commands, since the main daemon shouldn't know
# anything about our key auth
if payload['enc'] == 'clear' and payload['load']['cmd'] == '_auth':
self.send_clear(self._auth(payload['load']))
return None
payload['load'] = self.crypticle.loads(payload['load'])
except salt.crypt.AuthenticationError:
if not self._update_aes():
raise
payload['load'] = self.crypticle.loads(payload['load'])
return payload
def _auth(self, load):

View File

@ -48,47 +48,14 @@ class ReqServerChannel(object):
'''
pass
def post_fork(self):
def post_fork(self, payload_handler, io_loop):
'''
Do anything you need post-fork. This should be something like recv
Do anything you need post-fork. This should handle all incoming payloads
and call payload_handler. You will also be passed io_loop, for all of your
async needs
'''
pass
def recv(self, timeout=0):
'''
Get a req job, with an optional timeout
0: nonblocking
None: forever
'''
raise NotImplementedError()
@property
def socket(self):
'''
Return a socket (or fd) which can be used for poll mechanisms
'''
raise NotImplementedError()
def send_clear(self, payload):
'''
Send a response to a recv()'d payload
'''
raise NotImplementedError()
def send(self, payload):
'''
Send a response to a recv()'d payload
'''
raise NotImplementedError()
def send_private(self, payload, dictkey, target):
'''
Send a response to a recv()'d payload encrypted privately for target
'''
raise NotImplementedError()
class PubServerChannel(object):
'''
Factory class to create subscription channels to the master's Publisher

View File

@ -43,6 +43,7 @@ import zmq.eventloop.zmqstream
import tornado
import tornado.tcpserver
import tornado.gen
import tornado.concurrent
log = logging.getLogger(__name__)
@ -229,7 +230,6 @@ class TCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.transport
class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.transport.server.ReqServerChannel):
# TODO: opts!
backlog = 5
size = 16384
@property
def socket(self):
@ -244,83 +244,126 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._socket.bind((self.opts['interface'], int(self.opts['ret_port'])))
def post_fork(self):
def post_fork(self, payload_handler, io_loop):
'''
After forking we need to create all of the local sockets to listen to the
router
'''
self.socket.listen(self.backlog)
self.client = None # The client we are currently talking with
payload_handler: function to call with your payloads
'''
self.payload_handler = payload_handler
self.io_loop = io_loop
self.req_server = SaltMessageServer(self.handle_message, io_loop=io_loop)
self.socket.listen(self.backlog)
self.req_server.add_socket(self.socket)
self.serial = salt.payload.Serial(self.opts)
salt.transport.mixins.auth.AESReqServerMixin.post_fork(self)
self.epoll = select.epoll()
self.epoll.register(self.socket.fileno(), select.EPOLLIN)
# map of fd -> (socket, address)
self.sock_map = {}
def recv(self, timeout=0.001):
@tornado.gen.coroutine
def handle_message(self, stream, header, payload):
'''
Get a req job, with an optional timeout
0: nonblocking
None: forever
This is the main event loop of the TCP server. Since we aren't using an
event driven mechanism-- we'll allow the caller to determine when we should
listen next-- by controlling the timeout with which we are called
Handle incoming messages from underylying tcp streams
'''
if timeout:
timeout = timeout * 1000 # epoll takes milliseconds
socks = self.epoll.poll(timeout)
try:
payload = self._decode_payload(payload)
except Exception:
stream.write(frame_msg(self.serial.dumps('bad load')))
raise tornado.gen.Return()
# intercept the "_auth" commands, since the main daemon shouldn't know
# anything about our key auth
if payload['enc'] == 'clear' and payload['load']['cmd'] == '_auth':
stream.write(frame_msg(self.serial.dumps(self._auth(payload['load']))))
raise tornado.gen.Return()
# TODO: handle exceptions
try:
ret, req_opts = self.payload_handler(payload) # TODO: check if a future
except Exception as e:
log.error('Some exception handling a payload from minion', exc_info=True)
stream.close()
raise tornado.gen.Return()
req_fun = req_opts.get('fun', 'send')
if req_fun == 'send_clear':
ret['enc'] = 'clear'
stream.write(frame_msg(self.serial.dumps(ret), header=header))
elif req_fun == 'send':
stream.write(frame_msg(self.serial.dumps(self.crypticle.dumps(ret)), header=header))
elif req_fun == 'send_private':
stream.write(frame_msg(self.serial.dumps(self._encrypt_private(ret,
req_opts['key'],
req_opts['tgt'],
)),header=header))
else:
socks = self.epoll.poll()
for fd, event in socks:
# do we have a new client?
if fd == self.socket.fileno():
client, address = self.socket.accept()
self.epoll.register(client.fileno(), select.EPOLLIN)
self.sock_map[client.fileno()] = (client, address)
log.trace('New client at {0} connected to reqserver'.format(address))
if fd in self.sock_map:
client, address = self.sock_map[fd]
try:
payload = socket_frame_recv(client)
# if the client bombed out on us, we just soldier on
except socket.error as e:
log.trace('Socket error {0} communicating with {1}, closing connection'.format(e, address))
client.close()
self.epoll.unregister(fd)
del self.sock_map[fd]
continue
self.client = client
payload = self.serial.loads(payload)
payload = self._decode_payload(payload)
# if timeout was 0 and we got a None, we intercepted the job,
# so just queue up another recv()
if payload is None and timeout == None:
return self.recv(timeout=timeout)
return payload
if timeout is None:
return self.recv(timeout=timeout)
return None
log.error('Unknown req_fun {0}'.format(req_fun))
stream.close()
def _send(self, payload):
'''
Helper function to serialize and send payload
'''
print ('sending??')
try:
self.client.send(frame_msg(self.serial.dumps(payload)))
self.client.write(frame_msg(self.serial.dumps(payload)))
# if there was an error, close the socket out
except socket.error as e:
self.client.close()
epoll.unregister(self.client.fileno())
del self.sock_map[self.client.fileno()]
raise salt.exceptions.SaltClientError(e)
# always reset self.client
finally:
self.client = None
print ('sent')
class SaltMessageServer(tornado.tcpserver.TCPServer):
'''
Raw TCP server which will recieve all of the TCP streams and re-assemble
messages that are sent through to us
'''
def __init__(self, message_handler, *args, **kwargs):
super(SaltMessageServer, self).__init__(*args, **kwargs)
self.clients = []
self.message_handler = message_handler
@tornado.gen.coroutine
def handle_stream(self, stream, address):
'''
Handle incoming streams and add messages to the incoming queue
'''
print ('req client connected {0}'.format(address))
log.trace('Req client {0} connected'.format(address))
self.clients.append((stream, address))
try:
while True:
header_len = yield stream.read_until(' ')
header_raw = yield stream.read_bytes(int(header_len))
header = msgpack.loads(header_raw)
body_raw = yield stream.read_bytes(int(header['msgLen']))
body = msgpack.loads(body_raw)
self.message_handler(stream, header, body)
except tornado.iostream.StreamClosedError:
self.clients.remove((stream, address))
def shutdown(self):
'''
Shutdown the whole server
'''
for item in self.clients:
client, address = item
client.close()
self.clients.remove(item)
# TODO: subclass tcpclient
#class SaltMessageClient()
class PubServer(tornado.tcpserver.TCPServer):
'''