mirror of
https://github.com/valitydev/salt.git
synced 2024-11-08 17:33:54 +00:00
Switch ReqServer over to IOLoop
This commit is contained in:
parent
b1bb0d62d2
commit
5bee5649db
@ -28,6 +28,9 @@ import salt.ext.six as six
|
||||
from salt.ext.six.moves import range
|
||||
# pylint: enable=import-error,no-name-in-module,redefined-builtin
|
||||
|
||||
import zmq.eventloop.ioloop
|
||||
import tornado.gen
|
||||
|
||||
# Import salt libs
|
||||
import salt.crypt
|
||||
import salt.utils
|
||||
@ -651,51 +654,10 @@ class MWorker(multiprocessing.Process):
|
||||
'''
|
||||
Bind to the local port
|
||||
'''
|
||||
self.req_channel.post_fork() # TODO: cleaner? Maybe lazily?
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
payload = self.req_channel.recv(None) # blocking get
|
||||
# TODO: maybe change into a wrapper class?
|
||||
# req_opts defines our response function
|
||||
try:
|
||||
ret, req_opts = self._handle_payload(payload)
|
||||
except KeyError:
|
||||
ret, req_opts = 'error', {}
|
||||
log.debug('Exception when handling payload', exc_info=True)
|
||||
|
||||
# req_fun: default to send
|
||||
req_fun = req_opts.get('fun', 'send')
|
||||
if req_fun == 'send_clear':
|
||||
self.req_channel.send_clear(ret)
|
||||
elif req_fun == 'send':
|
||||
self.req_channel.send(ret)
|
||||
elif req_fun == 'send_private':
|
||||
self.req_channel.send_private(ret, req_opts['key'], req_opts['tgt'])
|
||||
else:
|
||||
log.error('Unknown req_fun {0}'.format(req_fun))
|
||||
# don't catch keyboard interrupts, just re-raise them
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
# catch all other exceptions, so we don't go defunct
|
||||
except Exception as exc:
|
||||
# since we are in an exceptional state, lets attempt to tell
|
||||
# the minion we have a problem, otherwise the minion will get
|
||||
# no response and be forced to wait for their max timeout
|
||||
try:
|
||||
socket.send('Unexpected Error in Mworker')
|
||||
except: # pylint: disable=W0702
|
||||
pass
|
||||
# Properly handle EINTR from SIGUSR1
|
||||
if isinstance(exc, zmq.ZMQError) and exc.errno == errno.EINTR:
|
||||
continue
|
||||
log.critical('Unexpected Error in MWorker',
|
||||
exc_info=True)
|
||||
|
||||
# Changes here create a zeromq condition, check with thatch45 before
|
||||
# making any zeromq changes
|
||||
except KeyboardInterrupt:
|
||||
del self.req_channel
|
||||
# using ZMQIOLoop since we *might* need zmq in there
|
||||
io_loop = zmq.eventloop.ioloop.ZMQIOLoop()
|
||||
self.req_channel.post_fork(self._handle_payload, io_loop=io_loop) # TODO: cleaner? Maybe lazily?
|
||||
io_loop.start()
|
||||
|
||||
def _handle_payload(self, payload):
|
||||
'''
|
||||
|
@ -127,23 +127,11 @@ class AESReqServerMixin(object):
|
||||
# we need to decrypt it
|
||||
if payload['enc'] == 'aes':
|
||||
try:
|
||||
try:
|
||||
payload['load'] = self.crypticle.loads(payload['load'])
|
||||
except salt.crypt.AuthenticationError:
|
||||
if not self._update_aes():
|
||||
raise
|
||||
payload['load'] = self.crypticle.loads(payload['load'])
|
||||
except Exception:
|
||||
# send something back to the client so the client so they know
|
||||
# their load was malformed
|
||||
self.send('bad load')
|
||||
raise
|
||||
|
||||
# intercept the "_auth" commands, since the main daemon shouldn't know
|
||||
# anything about our key auth
|
||||
if payload['enc'] == 'clear' and payload['load']['cmd'] == '_auth':
|
||||
self.send_clear(self._auth(payload['load']))
|
||||
return None
|
||||
payload['load'] = self.crypticle.loads(payload['load'])
|
||||
except salt.crypt.AuthenticationError:
|
||||
if not self._update_aes():
|
||||
raise
|
||||
payload['load'] = self.crypticle.loads(payload['load'])
|
||||
return payload
|
||||
|
||||
def _auth(self, load):
|
||||
|
@ -48,47 +48,14 @@ class ReqServerChannel(object):
|
||||
'''
|
||||
pass
|
||||
|
||||
def post_fork(self):
|
||||
def post_fork(self, payload_handler, io_loop):
|
||||
'''
|
||||
Do anything you need post-fork. This should be something like recv
|
||||
Do anything you need post-fork. This should handle all incoming payloads
|
||||
and call payload_handler. You will also be passed io_loop, for all of your
|
||||
async needs
|
||||
'''
|
||||
pass
|
||||
|
||||
def recv(self, timeout=0):
|
||||
'''
|
||||
Get a req job, with an optional timeout
|
||||
0: nonblocking
|
||||
None: forever
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def socket(self):
|
||||
'''
|
||||
Return a socket (or fd) which can be used for poll mechanisms
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_clear(self, payload):
|
||||
'''
|
||||
Send a response to a recv()'d payload
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def send(self, payload):
|
||||
'''
|
||||
Send a response to a recv()'d payload
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def send_private(self, payload, dictkey, target):
|
||||
'''
|
||||
Send a response to a recv()'d payload encrypted privately for target
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PubServerChannel(object):
|
||||
'''
|
||||
Factory class to create subscription channels to the master's Publisher
|
||||
|
@ -43,6 +43,7 @@ import zmq.eventloop.zmqstream
|
||||
import tornado
|
||||
import tornado.tcpserver
|
||||
import tornado.gen
|
||||
import tornado.concurrent
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -229,7 +230,6 @@ class TCPPubChannel(salt.transport.mixins.auth.AESPubClientMixin, salt.transport
|
||||
class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.transport.server.ReqServerChannel):
|
||||
# TODO: opts!
|
||||
backlog = 5
|
||||
size = 16384
|
||||
|
||||
@property
|
||||
def socket(self):
|
||||
@ -244,83 +244,126 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
|
||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self._socket.bind((self.opts['interface'], int(self.opts['ret_port'])))
|
||||
|
||||
def post_fork(self):
|
||||
def post_fork(self, payload_handler, io_loop):
|
||||
'''
|
||||
After forking we need to create all of the local sockets to listen to the
|
||||
router
|
||||
'''
|
||||
self.socket.listen(self.backlog)
|
||||
self.client = None # The client we are currently talking with
|
||||
|
||||
payload_handler: function to call with your payloads
|
||||
'''
|
||||
self.payload_handler = payload_handler
|
||||
self.io_loop = io_loop
|
||||
self.req_server = SaltMessageServer(self.handle_message, io_loop=io_loop)
|
||||
self.socket.listen(self.backlog)
|
||||
self.req_server.add_socket(self.socket)
|
||||
|
||||
self.serial = salt.payload.Serial(self.opts)
|
||||
salt.transport.mixins.auth.AESReqServerMixin.post_fork(self)
|
||||
|
||||
self.epoll = select.epoll()
|
||||
self.epoll.register(self.socket.fileno(), select.EPOLLIN)
|
||||
# map of fd -> (socket, address)
|
||||
self.sock_map = {}
|
||||
|
||||
def recv(self, timeout=0.001):
|
||||
@tornado.gen.coroutine
|
||||
def handle_message(self, stream, header, payload):
|
||||
'''
|
||||
Get a req job, with an optional timeout
|
||||
0: nonblocking
|
||||
None: forever
|
||||
|
||||
This is the main event loop of the TCP server. Since we aren't using an
|
||||
event driven mechanism-- we'll allow the caller to determine when we should
|
||||
listen next-- by controlling the timeout with which we are called
|
||||
Handle incoming messages from underylying tcp streams
|
||||
'''
|
||||
if timeout:
|
||||
timeout = timeout * 1000 # epoll takes milliseconds
|
||||
socks = self.epoll.poll(timeout)
|
||||
try:
|
||||
payload = self._decode_payload(payload)
|
||||
except Exception:
|
||||
stream.write(frame_msg(self.serial.dumps('bad load')))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# intercept the "_auth" commands, since the main daemon shouldn't know
|
||||
# anything about our key auth
|
||||
if payload['enc'] == 'clear' and payload['load']['cmd'] == '_auth':
|
||||
stream.write(frame_msg(self.serial.dumps(self._auth(payload['load']))))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# TODO: handle exceptions
|
||||
try:
|
||||
ret, req_opts = self.payload_handler(payload) # TODO: check if a future
|
||||
except Exception as e:
|
||||
log.error('Some exception handling a payload from minion', exc_info=True)
|
||||
stream.close()
|
||||
raise tornado.gen.Return()
|
||||
|
||||
req_fun = req_opts.get('fun', 'send')
|
||||
if req_fun == 'send_clear':
|
||||
ret['enc'] = 'clear'
|
||||
stream.write(frame_msg(self.serial.dumps(ret), header=header))
|
||||
elif req_fun == 'send':
|
||||
stream.write(frame_msg(self.serial.dumps(self.crypticle.dumps(ret)), header=header))
|
||||
elif req_fun == 'send_private':
|
||||
stream.write(frame_msg(self.serial.dumps(self._encrypt_private(ret,
|
||||
req_opts['key'],
|
||||
req_opts['tgt'],
|
||||
)),header=header))
|
||||
else:
|
||||
socks = self.epoll.poll()
|
||||
for fd, event in socks:
|
||||
# do we have a new client?
|
||||
if fd == self.socket.fileno():
|
||||
client, address = self.socket.accept()
|
||||
self.epoll.register(client.fileno(), select.EPOLLIN)
|
||||
self.sock_map[client.fileno()] = (client, address)
|
||||
log.trace('New client at {0} connected to reqserver'.format(address))
|
||||
if fd in self.sock_map:
|
||||
client, address = self.sock_map[fd]
|
||||
try:
|
||||
payload = socket_frame_recv(client)
|
||||
# if the client bombed out on us, we just soldier on
|
||||
except socket.error as e:
|
||||
log.trace('Socket error {0} communicating with {1}, closing connection'.format(e, address))
|
||||
client.close()
|
||||
self.epoll.unregister(fd)
|
||||
del self.sock_map[fd]
|
||||
continue
|
||||
self.client = client
|
||||
payload = self.serial.loads(payload)
|
||||
payload = self._decode_payload(payload)
|
||||
# if timeout was 0 and we got a None, we intercepted the job,
|
||||
# so just queue up another recv()
|
||||
if payload is None and timeout == None:
|
||||
return self.recv(timeout=timeout)
|
||||
return payload
|
||||
if timeout is None:
|
||||
return self.recv(timeout=timeout)
|
||||
return None
|
||||
log.error('Unknown req_fun {0}'.format(req_fun))
|
||||
stream.close()
|
||||
|
||||
|
||||
def _send(self, payload):
|
||||
'''
|
||||
Helper function to serialize and send payload
|
||||
'''
|
||||
print ('sending??')
|
||||
try:
|
||||
self.client.send(frame_msg(self.serial.dumps(payload)))
|
||||
self.client.write(frame_msg(self.serial.dumps(payload)))
|
||||
# if there was an error, close the socket out
|
||||
except socket.error as e:
|
||||
self.client.close()
|
||||
epoll.unregister(self.client.fileno())
|
||||
del self.sock_map[self.client.fileno()]
|
||||
raise salt.exceptions.SaltClientError(e)
|
||||
# always reset self.client
|
||||
finally:
|
||||
self.client = None
|
||||
print ('sent')
|
||||
|
||||
|
||||
class SaltMessageServer(tornado.tcpserver.TCPServer):
|
||||
'''
|
||||
Raw TCP server which will recieve all of the TCP streams and re-assemble
|
||||
messages that are sent through to us
|
||||
'''
|
||||
def __init__(self, message_handler, *args, **kwargs):
|
||||
super(SaltMessageServer, self).__init__(*args, **kwargs)
|
||||
|
||||
self.clients = []
|
||||
self.message_handler = message_handler
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
'''
|
||||
Handle incoming streams and add messages to the incoming queue
|
||||
'''
|
||||
print ('req client connected {0}'.format(address))
|
||||
log.trace('Req client {0} connected'.format(address))
|
||||
self.clients.append((stream, address))
|
||||
try:
|
||||
while True:
|
||||
header_len = yield stream.read_until(' ')
|
||||
header_raw = yield stream.read_bytes(int(header_len))
|
||||
header = msgpack.loads(header_raw)
|
||||
body_raw = yield stream.read_bytes(int(header['msgLen']))
|
||||
body = msgpack.loads(body_raw)
|
||||
self.message_handler(stream, header, body)
|
||||
|
||||
except tornado.iostream.StreamClosedError:
|
||||
self.clients.remove((stream, address))
|
||||
|
||||
def shutdown(self):
|
||||
'''
|
||||
Shutdown the whole server
|
||||
'''
|
||||
for item in self.clients:
|
||||
client, address = item
|
||||
client.close()
|
||||
self.clients.remove(item)
|
||||
|
||||
|
||||
|
||||
# TODO: subclass tcpclient
|
||||
#class SaltMessageClient()
|
||||
|
||||
|
||||
|
||||
class PubServer(tornado.tcpserver.TCPServer):
|
||||
'''
|
||||
|
Loading…
Reference in New Issue
Block a user