mirror of
https://github.com/valitydev/salt.git
synced 2024-11-07 08:58:59 +00:00
IPC transport skeleton
Basic IPC server works! Lint Skeleton of client IPC bind test Make stand-alone Adding factories for push and pull channels Allowing opts passing for consistency Tests now (mostly) work Lint Method documentation General cleanup. Migrate to inheritence. Log cleanup Migrate framing to stand-along module Migrate ipc.py to new framer Working except for serialization bug Debugging Debugging It works!! Remove ZeroMQ from TCP transport :] General cleanup Linting General cleanup Align socket name with what client expects Remove unused buffer size flag exception handling for stream close Calls to parent class inits Docs Remove debugging Remove unused function Remove unnecessary pre/post fork on msgclient Remove unecessary timeout flag Better stream/socket shutdown in server Remove unused handler Removing more unused More function cleanup Removing more unneeded cruft Lint Round out documentation More docs Misc hacks to fix up @cachedout's IPC This was using a mix of blocking and non-blocking calls, which was making a bit of a mess. connect and write are both non-blocking calls on IOStreams, so we either need to handle all the callbacks or do them in the coroutine fashion (much easier to manage). This meant that in the tests your "write" wouldn't make it out since we didn't wait on the connect. IMO we should refactor this IPC stuff to have proper async interfaces and wrap if absolutely necessary, but I think its reasonable to ask that as part of this we make some more of the core coroutines :) for #23236 Lint Remove init of io_loop because we require start() Various fixes Remove uneeded functionality Remove dup Cleanup and remove unused functions Moving toward coroutines More lint handle_connection changed to spawn Singletons for ipcclient Lint disable Remove redundent check in close() Remove duplicates in init Improved exception handling Test framework Require sock path to be passed in Better testing approach Remove unecessary __init__ Misc cleanup of unecessary methods Major rework of the IPC channels to make them work :) Remove TODO, since the feature was implemented Add more tests for IPC Add support for reconnecting clients, as well as a return from the IPCServer misc cleanup Lint test case Lint transport
This commit is contained in:
parent
3563894d03
commit
fcd9197f86
@ -36,6 +36,32 @@ class ReqChannel(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PushChannel(object):
|
||||
'''
|
||||
Factory class to create Sync channel for push side of push/pull IPC
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
sync = SyncWrapper(AsyncPushChannel.factory, (opts,), kwargs)
|
||||
return sync
|
||||
|
||||
def send(self, load, tries=3, timeout=60):
|
||||
'''
|
||||
Send load across IPC push
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PullChannel(object):
|
||||
'''
|
||||
Factory class to create Sync channel for pull side of push/pull IPC
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
sync = SyncWrapper(AsyncPullChannel.factory, (opts,), kwargs)
|
||||
return sync
|
||||
|
||||
|
||||
# TODO: better doc strings
|
||||
class AsyncChannel(object):
|
||||
'''
|
||||
@ -150,4 +176,34 @@ class AsyncPubChannel(AsyncChannel):
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AsyncPushChannel(object):
|
||||
'''
|
||||
Factory class to create IPC Push channels
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
'''
|
||||
If we have additional IPC transports other than UxD and TCP, add them here
|
||||
'''
|
||||
# FIXME for now, just UXD
|
||||
# Obviously, this makes the factory approach pointless, but we'll extend later
|
||||
import salt.transport.ipc
|
||||
return salt.transport.ipc.IPCMessageClient(opts, **kwargs)
|
||||
|
||||
|
||||
class AsyncPullChannel(object):
|
||||
'''
|
||||
Factory class to create IPC pull channels
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
'''
|
||||
If we have additional IPC transports other than UXD and TCP, add them here
|
||||
'''
|
||||
import salt.transport.ipc
|
||||
return salt.transport.ipc.IPCMessageServer(opts, **kwargs)
|
||||
|
||||
## Additional IPC messaging patterns should provide interfaces here, ala router/dealer, pub/sub, etc
|
||||
|
||||
# EOF
|
||||
|
25
salt/transport/frame.py
Normal file
25
salt/transport/frame.py
Normal file
@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Helper functions for transport components to handle message framing
|
||||
'''
|
||||
# Import python libs
|
||||
from __future__ import absolute_import
|
||||
import msgpack
|
||||
|
||||
|
||||
def frame_msg(body, header=None, raw_body=False):
|
||||
'''
|
||||
Frame the given message with our wire protocol
|
||||
'''
|
||||
framed_msg = {}
|
||||
if header is None:
|
||||
header = {}
|
||||
|
||||
# if the body wasn't already msgpacked-- lets do that.
|
||||
if not raw_body:
|
||||
body = msgpack.dumps(body)
|
||||
|
||||
framed_msg['head'] = header
|
||||
framed_msg['body'] = body
|
||||
framed_msg_packed = msgpack.dumps(framed_msg)
|
||||
return '{0} {1}'.format(len(framed_msg_packed), framed_msg_packed)
|
325
salt/transport/ipc.py
Normal file
325
salt/transport/ipc.py
Normal file
@ -0,0 +1,325 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
IPC transport classes
|
||||
'''
|
||||
|
||||
# Import Python libs
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
import socket
|
||||
import msgpack
|
||||
import weakref
|
||||
|
||||
# Import Tornado libs
|
||||
import tornado
|
||||
import tornado.gen
|
||||
import tornado.netutil
|
||||
import tornado.concurrent
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
|
||||
# Import Salt libs
|
||||
import salt.transport.client
|
||||
import salt.transport.frame
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPCServer(object):
|
||||
'''
|
||||
A Tornado IPC server very similar to Tornado's TCPServer class
|
||||
but using either UNIX domain sockets or TCP sockets
|
||||
'''
|
||||
def __init__(self, socket_path, io_loop=None, payload_handler=None):
|
||||
'''
|
||||
Create a new Tornado IPC server
|
||||
|
||||
:param IOLoop io_loop: A Tornado ioloop to handle scheduling
|
||||
:param func stream_handler: A function to customize handling of an
|
||||
incoming stream.
|
||||
'''
|
||||
self.socket_path = socket_path
|
||||
self._started = False
|
||||
self.payload_handler = payload_handler
|
||||
|
||||
# Placeholders for attributes to be populated by method calls
|
||||
self.stream = None
|
||||
self.sock = None
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
||||
def start(self):
|
||||
'''
|
||||
Perform the work necessary to start up a Tornado IPC server
|
||||
|
||||
Blocks until socket is established
|
||||
|
||||
:param str socket_path: Path on the filesystem for the socket to bind to.
|
||||
This socket does not need to exist prior to calling
|
||||
this method, but parent directories should.
|
||||
'''
|
||||
# Start up the ioloop
|
||||
log.trace('IPCServer: binding to socket: {0}'.format(self.socket_path))
|
||||
self.sock = tornado.netutil.bind_unix_socket(self.socket_path)
|
||||
|
||||
tornado.netutil.add_accept_handler(
|
||||
self.sock,
|
||||
self.handle_connection,
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
self._started = True
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def handle_stream(self, stream):
|
||||
'''
|
||||
Override this to handle the streams as they arrive
|
||||
|
||||
:param IOStream stream: An IOStream for processing
|
||||
|
||||
See http://tornado.readthedocs.org/en/latest/iostream.html#tornado.iostream.IOStream
|
||||
for additional details.
|
||||
'''
|
||||
@tornado.gen.coroutine
|
||||
def _null(msg):
|
||||
raise tornado.gen.Return(None)
|
||||
|
||||
def write_callback(stream, header):
|
||||
if header.get('mid'):
|
||||
@tornado.gen.coroutine
|
||||
def return_message(msg):
|
||||
pack = salt.transport.frame.frame_msg(
|
||||
msg,
|
||||
header={'mid': header['mid']},
|
||||
raw_body=True,
|
||||
)
|
||||
yield stream.write(pack)
|
||||
return return_message
|
||||
else:
|
||||
return _null
|
||||
while not stream.closed():
|
||||
try:
|
||||
framed_msg_len = yield stream.read_until(' ')
|
||||
framed_msg_raw = yield stream.read_bytes(int(framed_msg_len.strip()))
|
||||
framed_msg = msgpack.loads(framed_msg_raw)
|
||||
body = framed_msg['body']
|
||||
self.io_loop.spawn_callback(self.payload_handler, body, write_callback(stream, framed_msg['head']))
|
||||
except Exception as exc:
|
||||
log.error('Exception occurred while handling stream: {0}'.format(exc))
|
||||
|
||||
def handle_connection(self, connection, address):
|
||||
log.trace('IPCServer: Handling connection to address: {0}'.format(address))
|
||||
try:
|
||||
stream = IOStream(
|
||||
connection,
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
self.io_loop.spawn_callback(self.handle_stream, stream)
|
||||
except Exception as exc:
|
||||
log.error('IPC streaming error: {0}'.format(exc))
|
||||
|
||||
def close(self):
|
||||
'''
|
||||
Routines to handle any cleanup before the instance shuts down.
|
||||
Sockets and filehandles should be closed explicitely, to prevent
|
||||
leaks.
|
||||
'''
|
||||
if hasattr(self.stream, 'close'):
|
||||
self.stream.close()
|
||||
if hasattr(self.sock, 'close'):
|
||||
self.sock.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class IPCClient(object):
|
||||
'''
|
||||
A Tornado IPC client very similar to Tornado's TCPClient class
|
||||
but using either UNIX domain sockets or TCP sockets
|
||||
|
||||
This was written because Tornado does not have its own IPC
|
||||
server/client implementation.
|
||||
|
||||
:param IOLoop io_loop: A Tornado ioloop to handle scheduling
|
||||
:param str socket_path: A path on the filesystem where a socket
|
||||
belonging to a running IPCServer can be
|
||||
found.
|
||||
'''
|
||||
|
||||
# Create singleton map between two sockets
|
||||
instance_map = weakref.WeakKeyDictionary()
|
||||
|
||||
def __new__(cls, socket_path, io_loop=None):
|
||||
io_loop = io_loop or tornado.ioloop.IOLoop.current()
|
||||
if io_loop not in IPCClient.instance_map:
|
||||
IPCClient.instance_map[io_loop] = weakref.WeakValueDictionary()
|
||||
loop_instance_map = IPCClient.instance_map[io_loop]
|
||||
|
||||
# FIXME
|
||||
key = socket_path
|
||||
|
||||
if key not in loop_instance_map:
|
||||
log.debug('Initializing new IPCClient for path: {0}'.format(key))
|
||||
new_client = object.__new__(cls)
|
||||
# FIXME
|
||||
new_client.__singleton_init__(io_loop=io_loop, socket_path=socket_path)
|
||||
loop_instance_map[key] = new_client
|
||||
else:
|
||||
log.debug('Re-using IPCClient for {0}'.format(key))
|
||||
return loop_instance_map[key]
|
||||
|
||||
def __singleton_init__(self, socket_path, io_loop=None):
|
||||
'''
|
||||
Create a new IPC client
|
||||
|
||||
IPC clients cannot bind to ports, but must connect to
|
||||
existing IPC servers. Clients can then send messages
|
||||
to the server.
|
||||
|
||||
'''
|
||||
self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
|
||||
self.socket_path = socket_path
|
||||
self._closing = False
|
||||
|
||||
def __init__(self, socket_path, io_loop=None):
|
||||
# Handled by singleton __new__
|
||||
pass
|
||||
|
||||
def connected(self):
|
||||
return hasattr(self, 'stream')
|
||||
|
||||
def connect(self, callback=None):
|
||||
'''
|
||||
Connect to the IPC socket
|
||||
'''
|
||||
if hasattr(self, '_connecting_future') and not self._connecting_future.done(): # pylint: disable=E0203
|
||||
future = self._connecting_future # pylint: disable=E0203
|
||||
else:
|
||||
future = tornado.concurrent.Future()
|
||||
self._connecting_future = future
|
||||
self.io_loop.add_callback(self._connect)
|
||||
|
||||
if callback is not None:
|
||||
def handle_future(future):
|
||||
response = future.result()
|
||||
self.io_loop.add_callback(callback, response)
|
||||
future.add_done_callback(handle_future)
|
||||
return future
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _connect(self):
|
||||
'''
|
||||
Connect to a running IPCServer
|
||||
'''
|
||||
self.stream = IOStream(
|
||||
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
while True:
|
||||
if self._closing:
|
||||
break
|
||||
try:
|
||||
log.trace('IPCClient: Connecting to socket: {0}'.format(self.socket_path))
|
||||
yield self.stream.connect(self.socket_path)
|
||||
self._connecting_future.set_result(True)
|
||||
break
|
||||
except Exception as e:
|
||||
yield tornado.gen.sleep(1) # TODO: backoff
|
||||
#self._connecting_future.set_exception(e)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
'''
|
||||
Routines to handle any cleanup before the instance shuts down.
|
||||
Sockets and filehandles should be closed explicitely, to prevent
|
||||
leaks.
|
||||
'''
|
||||
self._closing = True
|
||||
if hasattr(self, 'stream'):
|
||||
self.stream.close()
|
||||
|
||||
|
||||
class IPCMessageClient(IPCClient):
|
||||
'''
|
||||
Salt IPC message client
|
||||
|
||||
Create an IPC client to send messages to an IPC server
|
||||
|
||||
An example of a very simple IPCMessageClient connecting to an IPCServer. This
|
||||
example assumes an already running IPCMessage server.
|
||||
|
||||
IMPORTANT: The below example also assumes a running IOLoop process.
|
||||
|
||||
# Import Tornado libs
|
||||
import tornado.ioloop
|
||||
|
||||
# Import Salt libs
|
||||
import salt.config
|
||||
import salt.transport.ipc
|
||||
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
|
||||
ipc_server_socket_path = '/var/run/ipc_server.ipc'
|
||||
|
||||
ipc_client = salt.transport.ipc.IPCMessageClient(ipc_server_socket_path, io_loop=io_loop)
|
||||
|
||||
# Connect to the server
|
||||
ipc_client.connect()
|
||||
|
||||
# Send some data
|
||||
ipc_client.send('Hello world')
|
||||
'''
|
||||
# FIXME timeout unimplemented
|
||||
# FIXME tries unimplemented
|
||||
@tornado.gen.coroutine
|
||||
def send(self, msg, timeout=None, tries=None):
|
||||
'''
|
||||
Send a message to an IPC socket
|
||||
|
||||
If the socket is not currently connected, a connection will be established.
|
||||
|
||||
:param dict msg: The message to be sent
|
||||
:param int timeout: Timeout when sending message (Currently unimplemented)
|
||||
'''
|
||||
if not self.connected():
|
||||
yield self.connect()
|
||||
pack = salt.transport.frame.frame_msg(msg, raw_body=True)
|
||||
yield self.stream.write(pack)
|
||||
|
||||
|
||||
class IPCMessageServer(IPCServer):
|
||||
'''
|
||||
Salt IPC message server
|
||||
|
||||
Creates a message server which can create and bind to a socket on a given
|
||||
path and then respond to messages asynchronously.
|
||||
|
||||
An example of a very simple IPCServer which prints received messages to
|
||||
a console:
|
||||
|
||||
# Import Tornado libs
|
||||
import tornado.ioloop
|
||||
|
||||
# Import Salt libs
|
||||
import salt.transport.ipc
|
||||
import salt.config
|
||||
|
||||
opts = salt.config.master_opts()
|
||||
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
ipc_server_socket_path = '/var/run/ipc_server.ipc'
|
||||
ipc_server = salt.transport.ipc.IPCMessageServer(opts, io_loop=io_loop
|
||||
stream_handler=print_to_console)
|
||||
# Bind to the socket and prepare to run
|
||||
ipc_server.start(ipc_server_socket_path)
|
||||
|
||||
# Start the server
|
||||
io_loop.start()
|
||||
|
||||
# This callback is run whenever a message is received
|
||||
def print_to_console(payload):
|
||||
print(payload)
|
||||
|
||||
See IPCMessageClient() for an example of sending messages to an IPCMessageServer instance
|
||||
'''
|
@ -38,6 +38,7 @@ class AESPubClientMixin(object):
|
||||
@tornado.gen.coroutine
|
||||
def _decode_payload(self, payload):
|
||||
# we need to decrypt it
|
||||
log.trace('Decoding payload: {0}'.format(payload))
|
||||
if payload['enc'] == 'aes':
|
||||
self._verify_master_signature(payload)
|
||||
try:
|
||||
|
@ -22,21 +22,16 @@ import salt.crypt
|
||||
import salt.utils
|
||||
import salt.utils.verify
|
||||
import salt.utils.event
|
||||
import salt.utils.async
|
||||
import salt.payload
|
||||
import salt.exceptions
|
||||
import salt.transport.frame
|
||||
import salt.transport.ipc
|
||||
import salt.transport.client
|
||||
import salt.transport.server
|
||||
import salt.transport.mixins.auth
|
||||
from salt.exceptions import SaltReqTimeoutError, SaltClientError
|
||||
|
||||
# for IPC (for now)
|
||||
import zmq
|
||||
import zmq.eventloop.ioloop
|
||||
# support pyzmq 13.0.x, TODO: remove once we force people to 14.0.x
|
||||
if not hasattr(zmq.eventloop.ioloop, 'ZMQIOLoop'):
|
||||
zmq.eventloop.ioloop.ZMQIOLoop = zmq.eventloop.ioloop.IOLoop
|
||||
import zmq.eventloop.zmqstream
|
||||
|
||||
# Import Tornado Libs
|
||||
import tornado
|
||||
import tornado.tcpserver
|
||||
@ -51,59 +46,6 @@ from Crypto.Cipher import PKCS1_OAEP
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def frame_msg(body, header=None, raw_body=False):
|
||||
'''
|
||||
Frame the given message with our wire protocol
|
||||
'''
|
||||
framed_msg = {}
|
||||
if header is None:
|
||||
header = {}
|
||||
|
||||
# if the body wasn't already msgpacked-- lets do that.
|
||||
if not raw_body:
|
||||
body = msgpack.dumps(body)
|
||||
|
||||
framed_msg['head'] = header
|
||||
framed_msg['body'] = body
|
||||
framed_msg_packed = msgpack.dumps(framed_msg)
|
||||
return '{0} {1}'.format(len(framed_msg_packed), framed_msg_packed)
|
||||
|
||||
|
||||
def socket_frame_recv(s, recv_size=4096):
|
||||
'''
|
||||
Retrieve a frame from socket
|
||||
'''
|
||||
# get the header size
|
||||
recv_buf = ''
|
||||
while ' ' not in recv_buf:
|
||||
data = s.recv(recv_size)
|
||||
if data == '':
|
||||
raise socket.error('Empty response!')
|
||||
else:
|
||||
recv_buf += data
|
||||
# once we have a space, we know how long the rest is
|
||||
header_len, buf = recv_buf.split(' ', 1)
|
||||
header_len = int(header_len)
|
||||
while len(buf) < header_len:
|
||||
data = s.recv(recv_size)
|
||||
if data == '':
|
||||
raise socket.error('msg stopped, we are missing some data!')
|
||||
else:
|
||||
buf += data
|
||||
|
||||
header = msgpack.loads(buf[:header_len])
|
||||
msg_len = int(header['msgLen'])
|
||||
buf = buf[header_len:]
|
||||
while len(buf) < msg_len:
|
||||
data = s.recv(recv_size)
|
||||
if data == '':
|
||||
raise socket.error('msg stopped, we are missing some data!')
|
||||
else:
|
||||
buf += data
|
||||
|
||||
return buf
|
||||
|
||||
|
||||
# TODO: move serial down into message library
|
||||
class AsyncTCPReqChannel(salt.transport.client.ReqChannel):
|
||||
'''
|
||||
@ -335,18 +277,20 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
|
||||
try:
|
||||
payload = self._decode_payload(payload)
|
||||
except Exception:
|
||||
stream.write(frame_msg('bad load', header=header))
|
||||
stream.write(salt.transport.frame.frame_msg('bad load', header=header))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# TODO helper functions to normalize payload?
|
||||
if not isinstance(payload, dict) or not isinstance(payload.get('load'), dict):
|
||||
yield stream.write(frame_msg('payload and load must be a dict', header=header))
|
||||
yield stream.write(salt.transport.frame.frame_msg(
|
||||
'payload and load must be a dict', header=header))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# intercept the "_auth" commands, since the main daemon shouldn't know
|
||||
# anything about our key auth
|
||||
if payload['enc'] == 'clear' and payload.get('load', {}).get('cmd') == '_auth':
|
||||
yield stream.write(frame_msg(self._auth(payload['load']), header=header))
|
||||
yield stream.write(salt.transport.frame.frame_msg(
|
||||
self._auth(payload['load']), header=header))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# TODO: test
|
||||
@ -361,11 +305,11 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
|
||||
|
||||
req_fun = req_opts.get('fun', 'send')
|
||||
if req_fun == 'send_clear':
|
||||
stream.write(frame_msg(ret, header=header))
|
||||
stream.write(salt.transport.frame.frame_msg(ret, header=header))
|
||||
elif req_fun == 'send':
|
||||
stream.write(frame_msg(self.crypticle.dumps(ret), header=header))
|
||||
stream.write(salt.transport.frame.frame_msg(self.crypticle.dumps(ret), header=header))
|
||||
elif req_fun == 'send_private':
|
||||
stream.write(frame_msg(self._encrypt_private(ret,
|
||||
stream.write(salt.transport.frame.frame_msg(self._encrypt_private(ret,
|
||||
req_opts['key'],
|
||||
req_opts['tgt'],
|
||||
), header=header))
|
||||
@ -422,6 +366,8 @@ class SaltMessageServer(tornado.tcpserver.TCPServer, object):
|
||||
self.clients.remove(item)
|
||||
|
||||
|
||||
# TODO consolidate with IPCClient
|
||||
# TODO: limit in-flight messages.
|
||||
# TODO: singleton? Something to not re-create the tcp connection so much
|
||||
class SaltMessageClient(object):
|
||||
'''
|
||||
@ -609,7 +555,7 @@ class SaltMessageClient(object):
|
||||
# if we don't have a send queue, we need to spawn the callback to do the sending
|
||||
if len(self.send_queue) == 0:
|
||||
self.io_loop.spawn_callback(self._stream_send)
|
||||
self.send_queue.append((message_id, frame_msg(msg, header=header)))
|
||||
self.send_queue.append((message_id, salt.transport.frame.frame_msg(msg, header=header)))
|
||||
return future
|
||||
|
||||
|
||||
@ -625,15 +571,16 @@ class PubServer(tornado.tcpserver.TCPServer, object):
|
||||
log.trace('Subscriber at {0} connected'.format(address))
|
||||
self.clients.append((stream, address))
|
||||
|
||||
# TODO: ACK the publish through IPC
|
||||
@tornado.gen.coroutine
|
||||
def publish_payload(self, package):
|
||||
log.trace('TCP PubServer starting to publish payload')
|
||||
package = package[0] # ZMQ (The IPC calling us) ism :/
|
||||
payload = frame_msg(salt.payload.unpackage(package)['payload'], raw_body=True)
|
||||
def publish_payload(self, payload, _):
|
||||
payload = salt.transport.frame.frame_msg(payload['payload'], raw_body=True)
|
||||
log.debug('TCP PubServer sending payload: {0}'.format(payload))
|
||||
to_remove = []
|
||||
for item in self.clients:
|
||||
client, address = item
|
||||
try:
|
||||
# Write the packed str
|
||||
f = client.write(payload)
|
||||
self.io_loop.add_future(f, lambda f: True)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
@ -647,9 +594,10 @@ class PubServer(tornado.tcpserver.TCPServer, object):
|
||||
|
||||
|
||||
class TCPPubServerChannel(salt.transport.server.PubServerChannel):
|
||||
def __init__(self, opts):
|
||||
def __init__(self, opts, io_loop=None):
|
||||
self.opts = opts
|
||||
self.serial = salt.payload.Serial(self.opts) # TODO: in init?
|
||||
self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
|
||||
|
||||
def _publish_daemon(self):
|
||||
'''
|
||||
@ -657,35 +605,28 @@ class TCPPubServerChannel(salt.transport.server.PubServerChannel):
|
||||
'''
|
||||
salt.utils.appendproctitle(self.__class__.__name__)
|
||||
|
||||
# Set up the context
|
||||
context = zmq.Context(1)
|
||||
# Prepare minion pull socket
|
||||
pull_sock = context.socket(zmq.PULL)
|
||||
pull_uri = 'ipc://{0}'.format(
|
||||
os.path.join(self.opts['sock_dir'], 'publish_pull_tcp.ipc')
|
||||
# Spin up the publisher
|
||||
pub_server = PubServer(io_loop=self.io_loop)
|
||||
pub_server.listen(int(self.opts['publish_port']), address=self.opts['interface'])
|
||||
|
||||
# Set up Salt IPC server
|
||||
pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
|
||||
pull_sock = salt.transport.ipc.IPCMessageServer(
|
||||
pull_uri,
|
||||
io_loop=self.io_loop,
|
||||
payload_handler=pub_server.publish_payload,
|
||||
)
|
||||
salt.utils.zeromq.check_ipc_path_max_len(pull_uri)
|
||||
|
||||
# Securely create socket
|
||||
log.info('Starting the Salt Puller on {0}'.format(pull_uri))
|
||||
old_umask = os.umask(0o177)
|
||||
try:
|
||||
pull_sock.bind(pull_uri)
|
||||
pull_sock.start()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
# load up the IOLoop
|
||||
io_loop = zmq.eventloop.ioloop.ZMQIOLoop()
|
||||
# add the publisher
|
||||
pub_server = PubServer(io_loop=io_loop)
|
||||
pub_server.listen(int(self.opts['publish_port']), address=self.opts['interface'])
|
||||
|
||||
# add our IPC
|
||||
stream = zmq.eventloop.zmqstream.ZMQStream(pull_sock, io_loop=io_loop)
|
||||
stream.on_recv(pub_server.publish_payload)
|
||||
|
||||
# run forever
|
||||
io_loop.start()
|
||||
self.io_loop.start()
|
||||
|
||||
def pre_fork(self, process_manager):
|
||||
'''
|
||||
@ -707,17 +648,20 @@ class TCPPubServerChannel(salt.transport.server.PubServerChannel):
|
||||
master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem')
|
||||
log.debug("Signing data packet")
|
||||
payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load'])
|
||||
# Send 0MQ to the publisher
|
||||
context = zmq.Context(1)
|
||||
pub_sock = context.socket(zmq.PUSH)
|
||||
pull_uri = 'ipc://{0}'.format(
|
||||
os.path.join(self.opts['sock_dir'], 'publish_pull_tcp.ipc')
|
||||
)
|
||||
pub_sock.connect(pull_uri)
|
||||
# Use the Salt IPC server
|
||||
pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
|
||||
# TODO: switch to the actual async interface
|
||||
#pub_sock = salt.transport.ipc.IPCMessageClient(self.opts, io_loop=self.io_loop)
|
||||
pub_sock = salt.utils.async.SyncWrapper(
|
||||
salt.transport.ipc.IPCMessageClient,
|
||||
(pull_uri,)
|
||||
)
|
||||
pub_sock.connect()
|
||||
|
||||
int_payload = {'payload': self.serial.dumps(payload)}
|
||||
|
||||
# add some targeting stuff for lists only (for now)
|
||||
if load['tgt_type'] == 'list':
|
||||
int_payload['topic_lst'] = load['tgt']
|
||||
|
||||
pub_sock.send(self.serial.dumps(int_payload))
|
||||
# Send it over IPC!
|
||||
pub_sock.send(int_payload)
|
||||
|
146
tests/unit/transport/ipc_test.py
Normal file
146
tests/unit/transport/ipc_test.py
Normal file
@ -0,0 +1,146 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
:codeauthor: :email:`Mike Place <mp@saltstack.com>`
|
||||
'''
|
||||
|
||||
# Import python libs
|
||||
from __future__ import absolute_import
|
||||
import os
|
||||
import logging
|
||||
|
||||
import tornado.gen
|
||||
import tornado.ioloop
|
||||
import tornado.testing
|
||||
|
||||
import salt.utils
|
||||
import salt.config
|
||||
import salt.exceptions
|
||||
import salt.transport.ipc
|
||||
import salt.transport.server
|
||||
import salt.transport.client
|
||||
|
||||
from salt.ext.six.moves import range
|
||||
|
||||
# Import Salt Testing libs
|
||||
import integration
|
||||
|
||||
from salttesting.mock import MagicMock
|
||||
from salttesting.helpers import ensure_in_syspath
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ensure_in_syspath('../')
|
||||
|
||||
|
||||
class BaseIPCReqCase(tornado.testing.AsyncTestCase):
|
||||
'''
|
||||
Test the req server/client pair
|
||||
'''
|
||||
def setUp(self):
|
||||
super(BaseIPCReqCase, self).setUp()
|
||||
self._start_handlers = dict(self.io_loop._handlers)
|
||||
self.socket_path = os.path.join(integration.TMP, 'ipc_test.ipc')
|
||||
|
||||
self.server_channel = salt.transport.ipc.IPCMessageServer(
|
||||
self.socket_path,
|
||||
io_loop=self.io_loop,
|
||||
payload_handler=self._handle_payload,
|
||||
)
|
||||
self.server_channel.start()
|
||||
|
||||
self.payloads = []
|
||||
|
||||
def tearDown(self):
|
||||
super(BaseIPCReqCase, self).tearDown()
|
||||
failures = []
|
||||
self.server_channel.close()
|
||||
os.unlink(self.socket_path)
|
||||
for k, v in self.io_loop._handlers.iteritems():
|
||||
if self._start_handlers.get(k) != v:
|
||||
failures.append((k, v))
|
||||
if len(failures) > 0:
|
||||
raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _handle_payload(self, payload, reply_func):
|
||||
self.payloads.append(payload)
|
||||
yield reply_func(payload)
|
||||
if isinstance(payload, dict) and payload.get('stop'):
|
||||
self.stop()
|
||||
|
||||
|
||||
class IPCMessageClient(BaseIPCReqCase):
|
||||
'''
|
||||
Test all of the clear msg stuff
|
||||
'''
|
||||
|
||||
def _get_channel(self):
|
||||
channel = salt.transport.ipc.IPCMessageClient(
|
||||
socket_path=self.socket_path,
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
channel.connect(callback=self.stop)
|
||||
self.wait()
|
||||
return channel
|
||||
|
||||
def setUp(self):
|
||||
super(IPCMessageClient, self).setUp()
|
||||
self.channel = self._get_channel()
|
||||
|
||||
def tearDown(self):
|
||||
super(IPCMessageClient, self).setUp()
|
||||
self.channel.close()
|
||||
|
||||
def test_basic_send(self):
|
||||
msg = {'foo': 'bar', 'stop': True}
|
||||
self.channel.send(msg)
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[0], msg)
|
||||
|
||||
def test_many_send(self):
|
||||
msgs = []
|
||||
self.server_channel.stream_handler = MagicMock()
|
||||
|
||||
for i in range(0, 1000):
|
||||
msgs.append('test_many_send_{0}'.format(i))
|
||||
|
||||
for i in msgs:
|
||||
self.channel.send(i)
|
||||
self.channel.send({'stop': True})
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[:-1], msgs)
|
||||
|
||||
def test_very_big_message(self):
|
||||
long_str = ''.join([str(num) for num in range(10**5)])
|
||||
msg = {'long_str': long_str, 'stop': True}
|
||||
self.channel.send(msg)
|
||||
self.wait()
|
||||
self.assertEqual(msg, self.payloads[0])
|
||||
|
||||
def test_multistream_sends(self):
|
||||
local_channel = self._get_channel()
|
||||
|
||||
for c in (self.channel, local_channel):
|
||||
c.send('foo')
|
||||
|
||||
self.channel.send({'stop': True})
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[:-1], ['foo', 'foo'])
|
||||
|
||||
def test_multistream_errors(self):
|
||||
local_channel = self._get_channel()
|
||||
|
||||
for c in (self.channel, local_channel):
|
||||
c.send(None)
|
||||
|
||||
for c in (self.channel, local_channel):
|
||||
c.send('foo')
|
||||
|
||||
self.channel.send({'stop': True})
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from integration import run_tests
|
||||
run_tests(IPCMessageClient, needs_daemon=False)
|
Loading…
Reference in New Issue
Block a user