mirror of
https://github.com/valitydev/salt.git
synced 2024-11-07 08:58:59 +00:00
Merge pull request #25407 from rallytime/bp-23236
Back-port #23236 to 2015.8
This commit is contained in:
commit
bd7c71e3e4
@ -36,6 +36,32 @@ class ReqChannel(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PushChannel(object):
|
||||
'''
|
||||
Factory class to create Sync channel for push side of push/pull IPC
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
sync = SyncWrapper(AsyncPushChannel.factory, (opts,), kwargs)
|
||||
return sync
|
||||
|
||||
def send(self, load, tries=3, timeout=60):
|
||||
'''
|
||||
Send load across IPC push
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PullChannel(object):
|
||||
'''
|
||||
Factory class to create Sync channel for pull side of push/pull IPC
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
sync = SyncWrapper(AsyncPullChannel.factory, (opts,), kwargs)
|
||||
return sync
|
||||
|
||||
|
||||
# TODO: better doc strings
|
||||
class AsyncChannel(object):
|
||||
'''
|
||||
@ -150,4 +176,34 @@ class AsyncPubChannel(AsyncChannel):
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AsyncPushChannel(object):
|
||||
'''
|
||||
Factory class to create IPC Push channels
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
'''
|
||||
If we have additional IPC transports other than UxD and TCP, add them here
|
||||
'''
|
||||
# FIXME for now, just UXD
|
||||
# Obviously, this makes the factory approach pointless, but we'll extend later
|
||||
import salt.transport.ipc
|
||||
return salt.transport.ipc.IPCMessageClient(opts, **kwargs)
|
||||
|
||||
|
||||
class AsyncPullChannel(object):
|
||||
'''
|
||||
Factory class to create IPC pull channels
|
||||
'''
|
||||
@staticmethod
|
||||
def factory(opts, **kwargs):
|
||||
'''
|
||||
If we have additional IPC transports other than UXD and TCP, add them here
|
||||
'''
|
||||
import salt.transport.ipc
|
||||
return salt.transport.ipc.IPCMessageServer(opts, **kwargs)
|
||||
|
||||
## Additional IPC messaging patterns should provide interfaces here, ala router/dealer, pub/sub, etc
|
||||
|
||||
# EOF
|
||||
|
25
salt/transport/frame.py
Normal file
25
salt/transport/frame.py
Normal file
@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
Helper functions for transport components to handle message framing
|
||||
'''
|
||||
# Import python libs
|
||||
from __future__ import absolute_import
|
||||
import msgpack
|
||||
|
||||
|
||||
def frame_msg(body, header=None, raw_body=False):
|
||||
'''
|
||||
Frame the given message with our wire protocol
|
||||
'''
|
||||
framed_msg = {}
|
||||
if header is None:
|
||||
header = {}
|
||||
|
||||
# if the body wasn't already msgpacked-- lets do that.
|
||||
if not raw_body:
|
||||
body = msgpack.dumps(body)
|
||||
|
||||
framed_msg['head'] = header
|
||||
framed_msg['body'] = body
|
||||
framed_msg_packed = msgpack.dumps(framed_msg)
|
||||
return '{0} {1}'.format(len(framed_msg_packed), framed_msg_packed)
|
325
salt/transport/ipc.py
Normal file
325
salt/transport/ipc.py
Normal file
@ -0,0 +1,325 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
IPC transport classes
|
||||
'''
|
||||
|
||||
# Import Python libs
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
import socket
|
||||
import msgpack
|
||||
import weakref
|
||||
|
||||
# Import Tornado libs
|
||||
import tornado
|
||||
import tornado.gen
|
||||
import tornado.netutil
|
||||
import tornado.concurrent
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream
|
||||
|
||||
# Import Salt libs
|
||||
import salt.transport.client
|
||||
import salt.transport.frame
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPCServer(object):
|
||||
'''
|
||||
A Tornado IPC server very similar to Tornado's TCPServer class
|
||||
but using either UNIX domain sockets or TCP sockets
|
||||
'''
|
||||
def __init__(self, socket_path, io_loop=None, payload_handler=None):
|
||||
'''
|
||||
Create a new Tornado IPC server
|
||||
|
||||
:param IOLoop io_loop: A Tornado ioloop to handle scheduling
|
||||
:param func stream_handler: A function to customize handling of an
|
||||
incoming stream.
|
||||
'''
|
||||
self.socket_path = socket_path
|
||||
self._started = False
|
||||
self.payload_handler = payload_handler
|
||||
|
||||
# Placeholders for attributes to be populated by method calls
|
||||
self.stream = None
|
||||
self.sock = None
|
||||
self.io_loop = io_loop or IOLoop.current()
|
||||
|
||||
def start(self):
|
||||
'''
|
||||
Perform the work necessary to start up a Tornado IPC server
|
||||
|
||||
Blocks until socket is established
|
||||
|
||||
:param str socket_path: Path on the filesystem for the socket to bind to.
|
||||
This socket does not need to exist prior to calling
|
||||
this method, but parent directories should.
|
||||
'''
|
||||
# Start up the ioloop
|
||||
log.trace('IPCServer: binding to socket: {0}'.format(self.socket_path))
|
||||
self.sock = tornado.netutil.bind_unix_socket(self.socket_path)
|
||||
|
||||
tornado.netutil.add_accept_handler(
|
||||
self.sock,
|
||||
self.handle_connection,
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
self._started = True
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def handle_stream(self, stream):
|
||||
'''
|
||||
Override this to handle the streams as they arrive
|
||||
|
||||
:param IOStream stream: An IOStream for processing
|
||||
|
||||
See http://tornado.readthedocs.org/en/latest/iostream.html#tornado.iostream.IOStream
|
||||
for additional details.
|
||||
'''
|
||||
@tornado.gen.coroutine
|
||||
def _null(msg):
|
||||
raise tornado.gen.Return(None)
|
||||
|
||||
def write_callback(stream, header):
|
||||
if header.get('mid'):
|
||||
@tornado.gen.coroutine
|
||||
def return_message(msg):
|
||||
pack = salt.transport.frame.frame_msg(
|
||||
msg,
|
||||
header={'mid': header['mid']},
|
||||
raw_body=True,
|
||||
)
|
||||
yield stream.write(pack)
|
||||
return return_message
|
||||
else:
|
||||
return _null
|
||||
while not stream.closed():
|
||||
try:
|
||||
framed_msg_len = yield stream.read_until(' ')
|
||||
framed_msg_raw = yield stream.read_bytes(int(framed_msg_len.strip()))
|
||||
framed_msg = msgpack.loads(framed_msg_raw)
|
||||
body = framed_msg['body']
|
||||
self.io_loop.spawn_callback(self.payload_handler, body, write_callback(stream, framed_msg['head']))
|
||||
except Exception as exc:
|
||||
log.error('Exception occurred while handling stream: {0}'.format(exc))
|
||||
|
||||
def handle_connection(self, connection, address):
|
||||
log.trace('IPCServer: Handling connection to address: {0}'.format(address))
|
||||
try:
|
||||
stream = IOStream(
|
||||
connection,
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
self.io_loop.spawn_callback(self.handle_stream, stream)
|
||||
except Exception as exc:
|
||||
log.error('IPC streaming error: {0}'.format(exc))
|
||||
|
||||
def close(self):
|
||||
'''
|
||||
Routines to handle any cleanup before the instance shuts down.
|
||||
Sockets and filehandles should be closed explicitely, to prevent
|
||||
leaks.
|
||||
'''
|
||||
if hasattr(self.stream, 'close'):
|
||||
self.stream.close()
|
||||
if hasattr(self.sock, 'close'):
|
||||
self.sock.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class IPCClient(object):
|
||||
'''
|
||||
A Tornado IPC client very similar to Tornado's TCPClient class
|
||||
but using either UNIX domain sockets or TCP sockets
|
||||
|
||||
This was written because Tornado does not have its own IPC
|
||||
server/client implementation.
|
||||
|
||||
:param IOLoop io_loop: A Tornado ioloop to handle scheduling
|
||||
:param str socket_path: A path on the filesystem where a socket
|
||||
belonging to a running IPCServer can be
|
||||
found.
|
||||
'''
|
||||
|
||||
# Create singleton map between two sockets
|
||||
instance_map = weakref.WeakKeyDictionary()
|
||||
|
||||
def __new__(cls, socket_path, io_loop=None):
|
||||
io_loop = io_loop or tornado.ioloop.IOLoop.current()
|
||||
if io_loop not in IPCClient.instance_map:
|
||||
IPCClient.instance_map[io_loop] = weakref.WeakValueDictionary()
|
||||
loop_instance_map = IPCClient.instance_map[io_loop]
|
||||
|
||||
# FIXME
|
||||
key = socket_path
|
||||
|
||||
if key not in loop_instance_map:
|
||||
log.debug('Initializing new IPCClient for path: {0}'.format(key))
|
||||
new_client = object.__new__(cls)
|
||||
# FIXME
|
||||
new_client.__singleton_init__(io_loop=io_loop, socket_path=socket_path)
|
||||
loop_instance_map[key] = new_client
|
||||
else:
|
||||
log.debug('Re-using IPCClient for {0}'.format(key))
|
||||
return loop_instance_map[key]
|
||||
|
||||
def __singleton_init__(self, socket_path, io_loop=None):
|
||||
'''
|
||||
Create a new IPC client
|
||||
|
||||
IPC clients cannot bind to ports, but must connect to
|
||||
existing IPC servers. Clients can then send messages
|
||||
to the server.
|
||||
|
||||
'''
|
||||
self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
|
||||
self.socket_path = socket_path
|
||||
self._closing = False
|
||||
|
||||
def __init__(self, socket_path, io_loop=None):
|
||||
# Handled by singleton __new__
|
||||
pass
|
||||
|
||||
def connected(self):
|
||||
return hasattr(self, 'stream')
|
||||
|
||||
def connect(self, callback=None):
|
||||
'''
|
||||
Connect to the IPC socket
|
||||
'''
|
||||
if hasattr(self, '_connecting_future') and not self._connecting_future.done(): # pylint: disable=E0203
|
||||
future = self._connecting_future # pylint: disable=E0203
|
||||
else:
|
||||
future = tornado.concurrent.Future()
|
||||
self._connecting_future = future
|
||||
self.io_loop.add_callback(self._connect)
|
||||
|
||||
if callback is not None:
|
||||
def handle_future(future):
|
||||
response = future.result()
|
||||
self.io_loop.add_callback(callback, response)
|
||||
future.add_done_callback(handle_future)
|
||||
return future
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _connect(self):
|
||||
'''
|
||||
Connect to a running IPCServer
|
||||
'''
|
||||
self.stream = IOStream(
|
||||
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
while True:
|
||||
if self._closing:
|
||||
break
|
||||
try:
|
||||
log.trace('IPCClient: Connecting to socket: {0}'.format(self.socket_path))
|
||||
yield self.stream.connect(self.socket_path)
|
||||
self._connecting_future.set_result(True)
|
||||
break
|
||||
except Exception as e:
|
||||
yield tornado.gen.sleep(1) # TODO: backoff
|
||||
#self._connecting_future.set_exception(e)
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
'''
|
||||
Routines to handle any cleanup before the instance shuts down.
|
||||
Sockets and filehandles should be closed explicitely, to prevent
|
||||
leaks.
|
||||
'''
|
||||
self._closing = True
|
||||
if hasattr(self, 'stream'):
|
||||
self.stream.close()
|
||||
|
||||
|
||||
class IPCMessageClient(IPCClient):
|
||||
'''
|
||||
Salt IPC message client
|
||||
|
||||
Create an IPC client to send messages to an IPC server
|
||||
|
||||
An example of a very simple IPCMessageClient connecting to an IPCServer. This
|
||||
example assumes an already running IPCMessage server.
|
||||
|
||||
IMPORTANT: The below example also assumes a running IOLoop process.
|
||||
|
||||
# Import Tornado libs
|
||||
import tornado.ioloop
|
||||
|
||||
# Import Salt libs
|
||||
import salt.config
|
||||
import salt.transport.ipc
|
||||
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
|
||||
ipc_server_socket_path = '/var/run/ipc_server.ipc'
|
||||
|
||||
ipc_client = salt.transport.ipc.IPCMessageClient(ipc_server_socket_path, io_loop=io_loop)
|
||||
|
||||
# Connect to the server
|
||||
ipc_client.connect()
|
||||
|
||||
# Send some data
|
||||
ipc_client.send('Hello world')
|
||||
'''
|
||||
# FIXME timeout unimplemented
|
||||
# FIXME tries unimplemented
|
||||
@tornado.gen.coroutine
|
||||
def send(self, msg, timeout=None, tries=None):
|
||||
'''
|
||||
Send a message to an IPC socket
|
||||
|
||||
If the socket is not currently connected, a connection will be established.
|
||||
|
||||
:param dict msg: The message to be sent
|
||||
:param int timeout: Timeout when sending message (Currently unimplemented)
|
||||
'''
|
||||
if not self.connected():
|
||||
yield self.connect()
|
||||
pack = salt.transport.frame.frame_msg(msg, raw_body=True)
|
||||
yield self.stream.write(pack)
|
||||
|
||||
|
||||
class IPCMessageServer(IPCServer):
|
||||
'''
|
||||
Salt IPC message server
|
||||
|
||||
Creates a message server which can create and bind to a socket on a given
|
||||
path and then respond to messages asynchronously.
|
||||
|
||||
An example of a very simple IPCServer which prints received messages to
|
||||
a console:
|
||||
|
||||
# Import Tornado libs
|
||||
import tornado.ioloop
|
||||
|
||||
# Import Salt libs
|
||||
import salt.transport.ipc
|
||||
import salt.config
|
||||
|
||||
opts = salt.config.master_opts()
|
||||
|
||||
io_loop = tornado.ioloop.IOLoop.current()
|
||||
ipc_server_socket_path = '/var/run/ipc_server.ipc'
|
||||
ipc_server = salt.transport.ipc.IPCMessageServer(opts, io_loop=io_loop
|
||||
stream_handler=print_to_console)
|
||||
# Bind to the socket and prepare to run
|
||||
ipc_server.start(ipc_server_socket_path)
|
||||
|
||||
# Start the server
|
||||
io_loop.start()
|
||||
|
||||
# This callback is run whenever a message is received
|
||||
def print_to_console(payload):
|
||||
print(payload)
|
||||
|
||||
See IPCMessageClient() for an example of sending messages to an IPCMessageServer instance
|
||||
'''
|
@ -38,6 +38,7 @@ class AESPubClientMixin(object):
|
||||
@tornado.gen.coroutine
|
||||
def _decode_payload(self, payload):
|
||||
# we need to decrypt it
|
||||
log.trace('Decoding payload: {0}'.format(payload))
|
||||
if payload['enc'] == 'aes':
|
||||
self._verify_master_signature(payload)
|
||||
try:
|
||||
|
@ -22,21 +22,16 @@ import salt.crypt
|
||||
import salt.utils
|
||||
import salt.utils.verify
|
||||
import salt.utils.event
|
||||
import salt.utils.async
|
||||
import salt.payload
|
||||
import salt.exceptions
|
||||
import salt.transport.frame
|
||||
import salt.transport.ipc
|
||||
import salt.transport.client
|
||||
import salt.transport.server
|
||||
import salt.transport.mixins.auth
|
||||
from salt.exceptions import SaltReqTimeoutError, SaltClientError
|
||||
|
||||
# for IPC (for now)
|
||||
import zmq
|
||||
import zmq.eventloop.ioloop
|
||||
# support pyzmq 13.0.x, TODO: remove once we force people to 14.0.x
|
||||
if not hasattr(zmq.eventloop.ioloop, 'ZMQIOLoop'):
|
||||
zmq.eventloop.ioloop.ZMQIOLoop = zmq.eventloop.ioloop.IOLoop
|
||||
import zmq.eventloop.zmqstream
|
||||
|
||||
# Import Tornado Libs
|
||||
import tornado
|
||||
import tornado.tcpserver
|
||||
@ -51,59 +46,6 @@ from Crypto.Cipher import PKCS1_OAEP
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def frame_msg(body, header=None, raw_body=False):
|
||||
'''
|
||||
Frame the given message with our wire protocol
|
||||
'''
|
||||
framed_msg = {}
|
||||
if header is None:
|
||||
header = {}
|
||||
|
||||
# if the body wasn't already msgpacked-- lets do that.
|
||||
if not raw_body:
|
||||
body = msgpack.dumps(body)
|
||||
|
||||
framed_msg['head'] = header
|
||||
framed_msg['body'] = body
|
||||
framed_msg_packed = msgpack.dumps(framed_msg)
|
||||
return '{0} {1}'.format(len(framed_msg_packed), framed_msg_packed)
|
||||
|
||||
|
||||
def socket_frame_recv(s, recv_size=4096):
|
||||
'''
|
||||
Retrieve a frame from socket
|
||||
'''
|
||||
# get the header size
|
||||
recv_buf = ''
|
||||
while ' ' not in recv_buf:
|
||||
data = s.recv(recv_size)
|
||||
if data == '':
|
||||
raise socket.error('Empty response!')
|
||||
else:
|
||||
recv_buf += data
|
||||
# once we have a space, we know how long the rest is
|
||||
header_len, buf = recv_buf.split(' ', 1)
|
||||
header_len = int(header_len)
|
||||
while len(buf) < header_len:
|
||||
data = s.recv(recv_size)
|
||||
if data == '':
|
||||
raise socket.error('msg stopped, we are missing some data!')
|
||||
else:
|
||||
buf += data
|
||||
|
||||
header = msgpack.loads(buf[:header_len])
|
||||
msg_len = int(header['msgLen'])
|
||||
buf = buf[header_len:]
|
||||
while len(buf) < msg_len:
|
||||
data = s.recv(recv_size)
|
||||
if data == '':
|
||||
raise socket.error('msg stopped, we are missing some data!')
|
||||
else:
|
||||
buf += data
|
||||
|
||||
return buf
|
||||
|
||||
|
||||
# TODO: move serial down into message library
|
||||
class AsyncTCPReqChannel(salt.transport.client.ReqChannel):
|
||||
'''
|
||||
@ -335,18 +277,20 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
|
||||
try:
|
||||
payload = self._decode_payload(payload)
|
||||
except Exception:
|
||||
stream.write(frame_msg('bad load', header=header))
|
||||
stream.write(salt.transport.frame.frame_msg('bad load', header=header))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# TODO helper functions to normalize payload?
|
||||
if not isinstance(payload, dict) or not isinstance(payload.get('load'), dict):
|
||||
yield stream.write(frame_msg('payload and load must be a dict', header=header))
|
||||
yield stream.write(salt.transport.frame.frame_msg(
|
||||
'payload and load must be a dict', header=header))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# intercept the "_auth" commands, since the main daemon shouldn't know
|
||||
# anything about our key auth
|
||||
if payload['enc'] == 'clear' and payload.get('load', {}).get('cmd') == '_auth':
|
||||
yield stream.write(frame_msg(self._auth(payload['load']), header=header))
|
||||
yield stream.write(salt.transport.frame.frame_msg(
|
||||
self._auth(payload['load']), header=header))
|
||||
raise tornado.gen.Return()
|
||||
|
||||
# TODO: test
|
||||
@ -361,11 +305,11 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
|
||||
|
||||
req_fun = req_opts.get('fun', 'send')
|
||||
if req_fun == 'send_clear':
|
||||
stream.write(frame_msg(ret, header=header))
|
||||
stream.write(salt.transport.frame.frame_msg(ret, header=header))
|
||||
elif req_fun == 'send':
|
||||
stream.write(frame_msg(self.crypticle.dumps(ret), header=header))
|
||||
stream.write(salt.transport.frame.frame_msg(self.crypticle.dumps(ret), header=header))
|
||||
elif req_fun == 'send_private':
|
||||
stream.write(frame_msg(self._encrypt_private(ret,
|
||||
stream.write(salt.transport.frame.frame_msg(self._encrypt_private(ret,
|
||||
req_opts['key'],
|
||||
req_opts['tgt'],
|
||||
), header=header))
|
||||
@ -422,6 +366,8 @@ class SaltMessageServer(tornado.tcpserver.TCPServer, object):
|
||||
self.clients.remove(item)
|
||||
|
||||
|
||||
# TODO consolidate with IPCClient
|
||||
# TODO: limit in-flight messages.
|
||||
# TODO: singleton? Something to not re-create the tcp connection so much
|
||||
class SaltMessageClient(object):
|
||||
'''
|
||||
@ -609,7 +555,7 @@ class SaltMessageClient(object):
|
||||
# if we don't have a send queue, we need to spawn the callback to do the sending
|
||||
if len(self.send_queue) == 0:
|
||||
self.io_loop.spawn_callback(self._stream_send)
|
||||
self.send_queue.append((message_id, frame_msg(msg, header=header)))
|
||||
self.send_queue.append((message_id, salt.transport.frame.frame_msg(msg, header=header)))
|
||||
return future
|
||||
|
||||
|
||||
@ -625,15 +571,16 @@ class PubServer(tornado.tcpserver.TCPServer, object):
|
||||
log.trace('Subscriber at {0} connected'.format(address))
|
||||
self.clients.append((stream, address))
|
||||
|
||||
# TODO: ACK the publish through IPC
|
||||
@tornado.gen.coroutine
|
||||
def publish_payload(self, package):
|
||||
log.trace('TCP PubServer starting to publish payload')
|
||||
package = package[0] # ZMQ (The IPC calling us) ism :/
|
||||
payload = frame_msg(salt.payload.unpackage(package)['payload'], raw_body=True)
|
||||
def publish_payload(self, payload, _):
|
||||
payload = salt.transport.frame.frame_msg(payload['payload'], raw_body=True)
|
||||
log.debug('TCP PubServer sending payload: {0}'.format(payload))
|
||||
to_remove = []
|
||||
for item in self.clients:
|
||||
client, address = item
|
||||
try:
|
||||
# Write the packed str
|
||||
f = client.write(payload)
|
||||
self.io_loop.add_future(f, lambda f: True)
|
||||
except tornado.iostream.StreamClosedError:
|
||||
@ -647,9 +594,10 @@ class PubServer(tornado.tcpserver.TCPServer, object):
|
||||
|
||||
|
||||
class TCPPubServerChannel(salt.transport.server.PubServerChannel):
|
||||
def __init__(self, opts):
|
||||
def __init__(self, opts, io_loop=None):
|
||||
self.opts = opts
|
||||
self.serial = salt.payload.Serial(self.opts) # TODO: in init?
|
||||
self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
|
||||
|
||||
def _publish_daemon(self):
|
||||
'''
|
||||
@ -657,35 +605,28 @@ class TCPPubServerChannel(salt.transport.server.PubServerChannel):
|
||||
'''
|
||||
salt.utils.appendproctitle(self.__class__.__name__)
|
||||
|
||||
# Set up the context
|
||||
context = zmq.Context(1)
|
||||
# Prepare minion pull socket
|
||||
pull_sock = context.socket(zmq.PULL)
|
||||
pull_uri = 'ipc://{0}'.format(
|
||||
os.path.join(self.opts['sock_dir'], 'publish_pull_tcp.ipc')
|
||||
# Spin up the publisher
|
||||
pub_server = PubServer(io_loop=self.io_loop)
|
||||
pub_server.listen(int(self.opts['publish_port']), address=self.opts['interface'])
|
||||
|
||||
# Set up Salt IPC server
|
||||
pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
|
||||
pull_sock = salt.transport.ipc.IPCMessageServer(
|
||||
pull_uri,
|
||||
io_loop=self.io_loop,
|
||||
payload_handler=pub_server.publish_payload,
|
||||
)
|
||||
salt.utils.zeromq.check_ipc_path_max_len(pull_uri)
|
||||
|
||||
# Securely create socket
|
||||
log.info('Starting the Salt Puller on {0}'.format(pull_uri))
|
||||
old_umask = os.umask(0o177)
|
||||
try:
|
||||
pull_sock.bind(pull_uri)
|
||||
pull_sock.start()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
# load up the IOLoop
|
||||
io_loop = zmq.eventloop.ioloop.ZMQIOLoop()
|
||||
# add the publisher
|
||||
pub_server = PubServer(io_loop=io_loop)
|
||||
pub_server.listen(int(self.opts['publish_port']), address=self.opts['interface'])
|
||||
|
||||
# add our IPC
|
||||
stream = zmq.eventloop.zmqstream.ZMQStream(pull_sock, io_loop=io_loop)
|
||||
stream.on_recv(pub_server.publish_payload)
|
||||
|
||||
# run forever
|
||||
io_loop.start()
|
||||
self.io_loop.start()
|
||||
|
||||
def pre_fork(self, process_manager):
|
||||
'''
|
||||
@ -707,17 +648,20 @@ class TCPPubServerChannel(salt.transport.server.PubServerChannel):
|
||||
master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem')
|
||||
log.debug("Signing data packet")
|
||||
payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load'])
|
||||
# Send 0MQ to the publisher
|
||||
context = zmq.Context(1)
|
||||
pub_sock = context.socket(zmq.PUSH)
|
||||
pull_uri = 'ipc://{0}'.format(
|
||||
os.path.join(self.opts['sock_dir'], 'publish_pull_tcp.ipc')
|
||||
)
|
||||
pub_sock.connect(pull_uri)
|
||||
# Use the Salt IPC server
|
||||
pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
|
||||
# TODO: switch to the actual async interface
|
||||
#pub_sock = salt.transport.ipc.IPCMessageClient(self.opts, io_loop=self.io_loop)
|
||||
pub_sock = salt.utils.async.SyncWrapper(
|
||||
salt.transport.ipc.IPCMessageClient,
|
||||
(pull_uri,)
|
||||
)
|
||||
pub_sock.connect()
|
||||
|
||||
int_payload = {'payload': self.serial.dumps(payload)}
|
||||
|
||||
# add some targeting stuff for lists only (for now)
|
||||
if load['tgt_type'] == 'list':
|
||||
int_payload['topic_lst'] = load['tgt']
|
||||
|
||||
pub_sock.send(self.serial.dumps(int_payload))
|
||||
# Send it over IPC!
|
||||
pub_sock.send(int_payload)
|
||||
|
146
tests/unit/transport/ipc_test.py
Normal file
146
tests/unit/transport/ipc_test.py
Normal file
@ -0,0 +1,146 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
'''
|
||||
:codeauthor: :email:`Mike Place <mp@saltstack.com>`
|
||||
'''
|
||||
|
||||
# Import python libs
|
||||
from __future__ import absolute_import
|
||||
import os
|
||||
import logging
|
||||
|
||||
import tornado.gen
|
||||
import tornado.ioloop
|
||||
import tornado.testing
|
||||
|
||||
import salt.utils
|
||||
import salt.config
|
||||
import salt.exceptions
|
||||
import salt.transport.ipc
|
||||
import salt.transport.server
|
||||
import salt.transport.client
|
||||
|
||||
from salt.ext.six.moves import range
|
||||
|
||||
# Import Salt Testing libs
|
||||
import integration
|
||||
|
||||
from salttesting.mock import MagicMock
|
||||
from salttesting.helpers import ensure_in_syspath
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ensure_in_syspath('../')
|
||||
|
||||
|
||||
class BaseIPCReqCase(tornado.testing.AsyncTestCase):
|
||||
'''
|
||||
Test the req server/client pair
|
||||
'''
|
||||
def setUp(self):
|
||||
super(BaseIPCReqCase, self).setUp()
|
||||
self._start_handlers = dict(self.io_loop._handlers)
|
||||
self.socket_path = os.path.join(integration.TMP, 'ipc_test.ipc')
|
||||
|
||||
self.server_channel = salt.transport.ipc.IPCMessageServer(
|
||||
self.socket_path,
|
||||
io_loop=self.io_loop,
|
||||
payload_handler=self._handle_payload,
|
||||
)
|
||||
self.server_channel.start()
|
||||
|
||||
self.payloads = []
|
||||
|
||||
def tearDown(self):
|
||||
super(BaseIPCReqCase, self).tearDown()
|
||||
failures = []
|
||||
self.server_channel.close()
|
||||
os.unlink(self.socket_path)
|
||||
for k, v in self.io_loop._handlers.iteritems():
|
||||
if self._start_handlers.get(k) != v:
|
||||
failures.append((k, v))
|
||||
if len(failures) > 0:
|
||||
raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
|
||||
|
||||
@tornado.gen.coroutine
|
||||
def _handle_payload(self, payload, reply_func):
|
||||
self.payloads.append(payload)
|
||||
yield reply_func(payload)
|
||||
if isinstance(payload, dict) and payload.get('stop'):
|
||||
self.stop()
|
||||
|
||||
|
||||
class IPCMessageClient(BaseIPCReqCase):
|
||||
'''
|
||||
Test all of the clear msg stuff
|
||||
'''
|
||||
|
||||
def _get_channel(self):
|
||||
channel = salt.transport.ipc.IPCMessageClient(
|
||||
socket_path=self.socket_path,
|
||||
io_loop=self.io_loop,
|
||||
)
|
||||
channel.connect(callback=self.stop)
|
||||
self.wait()
|
||||
return channel
|
||||
|
||||
def setUp(self):
|
||||
super(IPCMessageClient, self).setUp()
|
||||
self.channel = self._get_channel()
|
||||
|
||||
def tearDown(self):
|
||||
super(IPCMessageClient, self).setUp()
|
||||
self.channel.close()
|
||||
|
||||
def test_basic_send(self):
|
||||
msg = {'foo': 'bar', 'stop': True}
|
||||
self.channel.send(msg)
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[0], msg)
|
||||
|
||||
def test_many_send(self):
|
||||
msgs = []
|
||||
self.server_channel.stream_handler = MagicMock()
|
||||
|
||||
for i in range(0, 1000):
|
||||
msgs.append('test_many_send_{0}'.format(i))
|
||||
|
||||
for i in msgs:
|
||||
self.channel.send(i)
|
||||
self.channel.send({'stop': True})
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[:-1], msgs)
|
||||
|
||||
def test_very_big_message(self):
|
||||
long_str = ''.join([str(num) for num in range(10**5)])
|
||||
msg = {'long_str': long_str, 'stop': True}
|
||||
self.channel.send(msg)
|
||||
self.wait()
|
||||
self.assertEqual(msg, self.payloads[0])
|
||||
|
||||
def test_multistream_sends(self):
|
||||
local_channel = self._get_channel()
|
||||
|
||||
for c in (self.channel, local_channel):
|
||||
c.send('foo')
|
||||
|
||||
self.channel.send({'stop': True})
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[:-1], ['foo', 'foo'])
|
||||
|
||||
def test_multistream_errors(self):
|
||||
local_channel = self._get_channel()
|
||||
|
||||
for c in (self.channel, local_channel):
|
||||
c.send(None)
|
||||
|
||||
for c in (self.channel, local_channel):
|
||||
c.send('foo')
|
||||
|
||||
self.channel.send({'stop': True})
|
||||
self.wait()
|
||||
self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from integration import run_tests
|
||||
run_tests(IPCMessageClient, needs_daemon=False)
|
Loading…
Reference in New Issue
Block a user