Merge pull request #25407 from rallytime/bp-23236

Back-port #23236 to 2015.8
This commit is contained in:
Mike Place 2015-07-14 10:09:21 -06:00
commit bd7c71e3e4
6 changed files with 598 additions and 101 deletions

View File

@ -36,6 +36,32 @@ class ReqChannel(object):
raise NotImplementedError()
class PushChannel(object):
'''
Factory class to create Sync channel for push side of push/pull IPC
'''
@staticmethod
def factory(opts, **kwargs):
sync = SyncWrapper(AsyncPushChannel.factory, (opts,), kwargs)
return sync
def send(self, load, tries=3, timeout=60):
'''
Send load across IPC push
'''
raise NotImplementedError()
class PullChannel(object):
'''
Factory class to create Sync channel for pull side of push/pull IPC
'''
@staticmethod
def factory(opts, **kwargs):
sync = SyncWrapper(AsyncPullChannel.factory, (opts,), kwargs)
return sync
# TODO: better doc strings
class AsyncChannel(object):
'''
@ -150,4 +176,34 @@ class AsyncPubChannel(AsyncChannel):
'''
raise NotImplementedError()
class AsyncPushChannel(object):
'''
Factory class to create IPC Push channels
'''
@staticmethod
def factory(opts, **kwargs):
'''
If we have additional IPC transports other than UxD and TCP, add them here
'''
# FIXME for now, just UXD
# Obviously, this makes the factory approach pointless, but we'll extend later
import salt.transport.ipc
return salt.transport.ipc.IPCMessageClient(opts, **kwargs)
class AsyncPullChannel(object):
'''
Factory class to create IPC pull channels
'''
@staticmethod
def factory(opts, **kwargs):
'''
If we have additional IPC transports other than UXD and TCP, add them here
'''
import salt.transport.ipc
return salt.transport.ipc.IPCMessageServer(opts, **kwargs)
## Additional IPC messaging patterns should provide interfaces here, ala router/dealer, pub/sub, etc
# EOF

25
salt/transport/frame.py Normal file
View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
'''
Helper functions for transport components to handle message framing
'''
# Import python libs
from __future__ import absolute_import
import msgpack
def frame_msg(body, header=None, raw_body=False):
'''
Frame the given message with our wire protocol
'''
framed_msg = {}
if header is None:
header = {}
# if the body wasn't already msgpacked-- lets do that.
if not raw_body:
body = msgpack.dumps(body)
framed_msg['head'] = header
framed_msg['body'] = body
framed_msg_packed = msgpack.dumps(framed_msg)
return '{0} {1}'.format(len(framed_msg_packed), framed_msg_packed)

325
salt/transport/ipc.py Normal file
View File

@ -0,0 +1,325 @@
# -*- coding: utf-8 -*-
'''
IPC transport classes
'''
# Import Python libs
from __future__ import absolute_import
import logging
import socket
import msgpack
import weakref
# Import Tornado libs
import tornado
import tornado.gen
import tornado.netutil
import tornado.concurrent
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
# Import Salt libs
import salt.transport.client
import salt.transport.frame
log = logging.getLogger(__name__)
class IPCServer(object):
'''
A Tornado IPC server very similar to Tornado's TCPServer class
but using either UNIX domain sockets or TCP sockets
'''
def __init__(self, socket_path, io_loop=None, payload_handler=None):
'''
Create a new Tornado IPC server
:param IOLoop io_loop: A Tornado ioloop to handle scheduling
:param func stream_handler: A function to customize handling of an
incoming stream.
'''
self.socket_path = socket_path
self._started = False
self.payload_handler = payload_handler
# Placeholders for attributes to be populated by method calls
self.stream = None
self.sock = None
self.io_loop = io_loop or IOLoop.current()
def start(self):
'''
Perform the work necessary to start up a Tornado IPC server
Blocks until socket is established
:param str socket_path: Path on the filesystem for the socket to bind to.
This socket does not need to exist prior to calling
this method, but parent directories should.
'''
# Start up the ioloop
log.trace('IPCServer: binding to socket: {0}'.format(self.socket_path))
self.sock = tornado.netutil.bind_unix_socket(self.socket_path)
tornado.netutil.add_accept_handler(
self.sock,
self.handle_connection,
io_loop=self.io_loop,
)
self._started = True
@tornado.gen.coroutine
def handle_stream(self, stream):
'''
Override this to handle the streams as they arrive
:param IOStream stream: An IOStream for processing
See http://tornado.readthedocs.org/en/latest/iostream.html#tornado.iostream.IOStream
for additional details.
'''
@tornado.gen.coroutine
def _null(msg):
raise tornado.gen.Return(None)
def write_callback(stream, header):
if header.get('mid'):
@tornado.gen.coroutine
def return_message(msg):
pack = salt.transport.frame.frame_msg(
msg,
header={'mid': header['mid']},
raw_body=True,
)
yield stream.write(pack)
return return_message
else:
return _null
while not stream.closed():
try:
framed_msg_len = yield stream.read_until(' ')
framed_msg_raw = yield stream.read_bytes(int(framed_msg_len.strip()))
framed_msg = msgpack.loads(framed_msg_raw)
body = framed_msg['body']
self.io_loop.spawn_callback(self.payload_handler, body, write_callback(stream, framed_msg['head']))
except Exception as exc:
log.error('Exception occurred while handling stream: {0}'.format(exc))
def handle_connection(self, connection, address):
log.trace('IPCServer: Handling connection to address: {0}'.format(address))
try:
stream = IOStream(
connection,
io_loop=self.io_loop,
)
self.io_loop.spawn_callback(self.handle_stream, stream)
except Exception as exc:
log.error('IPC streaming error: {0}'.format(exc))
def close(self):
'''
Routines to handle any cleanup before the instance shuts down.
Sockets and filehandles should be closed explicitely, to prevent
leaks.
'''
if hasattr(self.stream, 'close'):
self.stream.close()
if hasattr(self.sock, 'close'):
self.sock.close()
def __del__(self):
self.close()
class IPCClient(object):
'''
A Tornado IPC client very similar to Tornado's TCPClient class
but using either UNIX domain sockets or TCP sockets
This was written because Tornado does not have its own IPC
server/client implementation.
:param IOLoop io_loop: A Tornado ioloop to handle scheduling
:param str socket_path: A path on the filesystem where a socket
belonging to a running IPCServer can be
found.
'''
# Create singleton map between two sockets
instance_map = weakref.WeakKeyDictionary()
def __new__(cls, socket_path, io_loop=None):
io_loop = io_loop or tornado.ioloop.IOLoop.current()
if io_loop not in IPCClient.instance_map:
IPCClient.instance_map[io_loop] = weakref.WeakValueDictionary()
loop_instance_map = IPCClient.instance_map[io_loop]
# FIXME
key = socket_path
if key not in loop_instance_map:
log.debug('Initializing new IPCClient for path: {0}'.format(key))
new_client = object.__new__(cls)
# FIXME
new_client.__singleton_init__(io_loop=io_loop, socket_path=socket_path)
loop_instance_map[key] = new_client
else:
log.debug('Re-using IPCClient for {0}'.format(key))
return loop_instance_map[key]
def __singleton_init__(self, socket_path, io_loop=None):
'''
Create a new IPC client
IPC clients cannot bind to ports, but must connect to
existing IPC servers. Clients can then send messages
to the server.
'''
self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
self.socket_path = socket_path
self._closing = False
def __init__(self, socket_path, io_loop=None):
# Handled by singleton __new__
pass
def connected(self):
return hasattr(self, 'stream')
def connect(self, callback=None):
'''
Connect to the IPC socket
'''
if hasattr(self, '_connecting_future') and not self._connecting_future.done(): # pylint: disable=E0203
future = self._connecting_future # pylint: disable=E0203
else:
future = tornado.concurrent.Future()
self._connecting_future = future
self.io_loop.add_callback(self._connect)
if callback is not None:
def handle_future(future):
response = future.result()
self.io_loop.add_callback(callback, response)
future.add_done_callback(handle_future)
return future
@tornado.gen.coroutine
def _connect(self):
'''
Connect to a running IPCServer
'''
self.stream = IOStream(
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM),
io_loop=self.io_loop,
)
while True:
if self._closing:
break
try:
log.trace('IPCClient: Connecting to socket: {0}'.format(self.socket_path))
yield self.stream.connect(self.socket_path)
self._connecting_future.set_result(True)
break
except Exception as e:
yield tornado.gen.sleep(1) # TODO: backoff
#self._connecting_future.set_exception(e)
def __del__(self):
self.close()
def close(self):
'''
Routines to handle any cleanup before the instance shuts down.
Sockets and filehandles should be closed explicitely, to prevent
leaks.
'''
self._closing = True
if hasattr(self, 'stream'):
self.stream.close()
class IPCMessageClient(IPCClient):
'''
Salt IPC message client
Create an IPC client to send messages to an IPC server
An example of a very simple IPCMessageClient connecting to an IPCServer. This
example assumes an already running IPCMessage server.
IMPORTANT: The below example also assumes a running IOLoop process.
# Import Tornado libs
import tornado.ioloop
# Import Salt libs
import salt.config
import salt.transport.ipc
io_loop = tornado.ioloop.IOLoop.current()
ipc_server_socket_path = '/var/run/ipc_server.ipc'
ipc_client = salt.transport.ipc.IPCMessageClient(ipc_server_socket_path, io_loop=io_loop)
# Connect to the server
ipc_client.connect()
# Send some data
ipc_client.send('Hello world')
'''
# FIXME timeout unimplemented
# FIXME tries unimplemented
@tornado.gen.coroutine
def send(self, msg, timeout=None, tries=None):
'''
Send a message to an IPC socket
If the socket is not currently connected, a connection will be established.
:param dict msg: The message to be sent
:param int timeout: Timeout when sending message (Currently unimplemented)
'''
if not self.connected():
yield self.connect()
pack = salt.transport.frame.frame_msg(msg, raw_body=True)
yield self.stream.write(pack)
class IPCMessageServer(IPCServer):
'''
Salt IPC message server
Creates a message server which can create and bind to a socket on a given
path and then respond to messages asynchronously.
An example of a very simple IPCServer which prints received messages to
a console:
# Import Tornado libs
import tornado.ioloop
# Import Salt libs
import salt.transport.ipc
import salt.config
opts = salt.config.master_opts()
io_loop = tornado.ioloop.IOLoop.current()
ipc_server_socket_path = '/var/run/ipc_server.ipc'
ipc_server = salt.transport.ipc.IPCMessageServer(opts, io_loop=io_loop
stream_handler=print_to_console)
# Bind to the socket and prepare to run
ipc_server.start(ipc_server_socket_path)
# Start the server
io_loop.start()
# This callback is run whenever a message is received
def print_to_console(payload):
print(payload)
See IPCMessageClient() for an example of sending messages to an IPCMessageServer instance
'''

View File

@ -38,6 +38,7 @@ class AESPubClientMixin(object):
@tornado.gen.coroutine
def _decode_payload(self, payload):
# we need to decrypt it
log.trace('Decoding payload: {0}'.format(payload))
if payload['enc'] == 'aes':
self._verify_master_signature(payload)
try:

View File

@ -22,21 +22,16 @@ import salt.crypt
import salt.utils
import salt.utils.verify
import salt.utils.event
import salt.utils.async
import salt.payload
import salt.exceptions
import salt.transport.frame
import salt.transport.ipc
import salt.transport.client
import salt.transport.server
import salt.transport.mixins.auth
from salt.exceptions import SaltReqTimeoutError, SaltClientError
# for IPC (for now)
import zmq
import zmq.eventloop.ioloop
# support pyzmq 13.0.x, TODO: remove once we force people to 14.0.x
if not hasattr(zmq.eventloop.ioloop, 'ZMQIOLoop'):
zmq.eventloop.ioloop.ZMQIOLoop = zmq.eventloop.ioloop.IOLoop
import zmq.eventloop.zmqstream
# Import Tornado Libs
import tornado
import tornado.tcpserver
@ -51,59 +46,6 @@ from Crypto.Cipher import PKCS1_OAEP
log = logging.getLogger(__name__)
def frame_msg(body, header=None, raw_body=False):
'''
Frame the given message with our wire protocol
'''
framed_msg = {}
if header is None:
header = {}
# if the body wasn't already msgpacked-- lets do that.
if not raw_body:
body = msgpack.dumps(body)
framed_msg['head'] = header
framed_msg['body'] = body
framed_msg_packed = msgpack.dumps(framed_msg)
return '{0} {1}'.format(len(framed_msg_packed), framed_msg_packed)
def socket_frame_recv(s, recv_size=4096):
'''
Retrieve a frame from socket
'''
# get the header size
recv_buf = ''
while ' ' not in recv_buf:
data = s.recv(recv_size)
if data == '':
raise socket.error('Empty response!')
else:
recv_buf += data
# once we have a space, we know how long the rest is
header_len, buf = recv_buf.split(' ', 1)
header_len = int(header_len)
while len(buf) < header_len:
data = s.recv(recv_size)
if data == '':
raise socket.error('msg stopped, we are missing some data!')
else:
buf += data
header = msgpack.loads(buf[:header_len])
msg_len = int(header['msgLen'])
buf = buf[header_len:]
while len(buf) < msg_len:
data = s.recv(recv_size)
if data == '':
raise socket.error('msg stopped, we are missing some data!')
else:
buf += data
return buf
# TODO: move serial down into message library
class AsyncTCPReqChannel(salt.transport.client.ReqChannel):
'''
@ -335,18 +277,20 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
try:
payload = self._decode_payload(payload)
except Exception:
stream.write(frame_msg('bad load', header=header))
stream.write(salt.transport.frame.frame_msg('bad load', header=header))
raise tornado.gen.Return()
# TODO helper functions to normalize payload?
if not isinstance(payload, dict) or not isinstance(payload.get('load'), dict):
yield stream.write(frame_msg('payload and load must be a dict', header=header))
yield stream.write(salt.transport.frame.frame_msg(
'payload and load must be a dict', header=header))
raise tornado.gen.Return()
# intercept the "_auth" commands, since the main daemon shouldn't know
# anything about our key auth
if payload['enc'] == 'clear' and payload.get('load', {}).get('cmd') == '_auth':
yield stream.write(frame_msg(self._auth(payload['load']), header=header))
yield stream.write(salt.transport.frame.frame_msg(
self._auth(payload['load']), header=header))
raise tornado.gen.Return()
# TODO: test
@ -361,11 +305,11 @@ class TCPReqServerChannel(salt.transport.mixins.auth.AESReqServerMixin, salt.tra
req_fun = req_opts.get('fun', 'send')
if req_fun == 'send_clear':
stream.write(frame_msg(ret, header=header))
stream.write(salt.transport.frame.frame_msg(ret, header=header))
elif req_fun == 'send':
stream.write(frame_msg(self.crypticle.dumps(ret), header=header))
stream.write(salt.transport.frame.frame_msg(self.crypticle.dumps(ret), header=header))
elif req_fun == 'send_private':
stream.write(frame_msg(self._encrypt_private(ret,
stream.write(salt.transport.frame.frame_msg(self._encrypt_private(ret,
req_opts['key'],
req_opts['tgt'],
), header=header))
@ -422,6 +366,8 @@ class SaltMessageServer(tornado.tcpserver.TCPServer, object):
self.clients.remove(item)
# TODO consolidate with IPCClient
# TODO: limit in-flight messages.
# TODO: singleton? Something to not re-create the tcp connection so much
class SaltMessageClient(object):
'''
@ -609,7 +555,7 @@ class SaltMessageClient(object):
# if we don't have a send queue, we need to spawn the callback to do the sending
if len(self.send_queue) == 0:
self.io_loop.spawn_callback(self._stream_send)
self.send_queue.append((message_id, frame_msg(msg, header=header)))
self.send_queue.append((message_id, salt.transport.frame.frame_msg(msg, header=header)))
return future
@ -625,15 +571,16 @@ class PubServer(tornado.tcpserver.TCPServer, object):
log.trace('Subscriber at {0} connected'.format(address))
self.clients.append((stream, address))
# TODO: ACK the publish through IPC
@tornado.gen.coroutine
def publish_payload(self, package):
log.trace('TCP PubServer starting to publish payload')
package = package[0] # ZMQ (The IPC calling us) ism :/
payload = frame_msg(salt.payload.unpackage(package)['payload'], raw_body=True)
def publish_payload(self, payload, _):
payload = salt.transport.frame.frame_msg(payload['payload'], raw_body=True)
log.debug('TCP PubServer sending payload: {0}'.format(payload))
to_remove = []
for item in self.clients:
client, address = item
try:
# Write the packed str
f = client.write(payload)
self.io_loop.add_future(f, lambda f: True)
except tornado.iostream.StreamClosedError:
@ -647,9 +594,10 @@ class PubServer(tornado.tcpserver.TCPServer, object):
class TCPPubServerChannel(salt.transport.server.PubServerChannel):
def __init__(self, opts):
def __init__(self, opts, io_loop=None):
self.opts = opts
self.serial = salt.payload.Serial(self.opts) # TODO: in init?
self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
def _publish_daemon(self):
'''
@ -657,35 +605,28 @@ class TCPPubServerChannel(salt.transport.server.PubServerChannel):
'''
salt.utils.appendproctitle(self.__class__.__name__)
# Set up the context
context = zmq.Context(1)
# Prepare minion pull socket
pull_sock = context.socket(zmq.PULL)
pull_uri = 'ipc://{0}'.format(
os.path.join(self.opts['sock_dir'], 'publish_pull_tcp.ipc')
# Spin up the publisher
pub_server = PubServer(io_loop=self.io_loop)
pub_server.listen(int(self.opts['publish_port']), address=self.opts['interface'])
# Set up Salt IPC server
pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
pull_sock = salt.transport.ipc.IPCMessageServer(
pull_uri,
io_loop=self.io_loop,
payload_handler=pub_server.publish_payload,
)
salt.utils.zeromq.check_ipc_path_max_len(pull_uri)
# Securely create socket
log.info('Starting the Salt Puller on {0}'.format(pull_uri))
old_umask = os.umask(0o177)
try:
pull_sock.bind(pull_uri)
pull_sock.start()
finally:
os.umask(old_umask)
# load up the IOLoop
io_loop = zmq.eventloop.ioloop.ZMQIOLoop()
# add the publisher
pub_server = PubServer(io_loop=io_loop)
pub_server.listen(int(self.opts['publish_port']), address=self.opts['interface'])
# add our IPC
stream = zmq.eventloop.zmqstream.ZMQStream(pull_sock, io_loop=io_loop)
stream.on_recv(pub_server.publish_payload)
# run forever
io_loop.start()
self.io_loop.start()
def pre_fork(self, process_manager):
'''
@ -707,17 +648,20 @@ class TCPPubServerChannel(salt.transport.server.PubServerChannel):
master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem')
log.debug("Signing data packet")
payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load'])
# Send 0MQ to the publisher
context = zmq.Context(1)
pub_sock = context.socket(zmq.PUSH)
pull_uri = 'ipc://{0}'.format(
os.path.join(self.opts['sock_dir'], 'publish_pull_tcp.ipc')
)
pub_sock.connect(pull_uri)
# Use the Salt IPC server
pull_uri = os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
# TODO: switch to the actual async interface
#pub_sock = salt.transport.ipc.IPCMessageClient(self.opts, io_loop=self.io_loop)
pub_sock = salt.utils.async.SyncWrapper(
salt.transport.ipc.IPCMessageClient,
(pull_uri,)
)
pub_sock.connect()
int_payload = {'payload': self.serial.dumps(payload)}
# add some targeting stuff for lists only (for now)
if load['tgt_type'] == 'list':
int_payload['topic_lst'] = load['tgt']
pub_sock.send(self.serial.dumps(int_payload))
# Send it over IPC!
pub_sock.send(int_payload)

View File

@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
'''
:codeauthor: :email:`Mike Place <mp@saltstack.com>`
'''
# Import python libs
from __future__ import absolute_import
import os
import logging
import tornado.gen
import tornado.ioloop
import tornado.testing
import salt.utils
import salt.config
import salt.exceptions
import salt.transport.ipc
import salt.transport.server
import salt.transport.client
from salt.ext.six.moves import range
# Import Salt Testing libs
import integration
from salttesting.mock import MagicMock
from salttesting.helpers import ensure_in_syspath
log = logging.getLogger(__name__)
ensure_in_syspath('../')
class BaseIPCReqCase(tornado.testing.AsyncTestCase):
'''
Test the req server/client pair
'''
def setUp(self):
super(BaseIPCReqCase, self).setUp()
self._start_handlers = dict(self.io_loop._handlers)
self.socket_path = os.path.join(integration.TMP, 'ipc_test.ipc')
self.server_channel = salt.transport.ipc.IPCMessageServer(
self.socket_path,
io_loop=self.io_loop,
payload_handler=self._handle_payload,
)
self.server_channel.start()
self.payloads = []
def tearDown(self):
super(BaseIPCReqCase, self).tearDown()
failures = []
self.server_channel.close()
os.unlink(self.socket_path)
for k, v in self.io_loop._handlers.iteritems():
if self._start_handlers.get(k) != v:
failures.append((k, v))
if len(failures) > 0:
raise Exception('FDs still attached to the IOLoop: {0}'.format(failures))
@tornado.gen.coroutine
def _handle_payload(self, payload, reply_func):
self.payloads.append(payload)
yield reply_func(payload)
if isinstance(payload, dict) and payload.get('stop'):
self.stop()
class IPCMessageClient(BaseIPCReqCase):
'''
Test all of the clear msg stuff
'''
def _get_channel(self):
channel = salt.transport.ipc.IPCMessageClient(
socket_path=self.socket_path,
io_loop=self.io_loop,
)
channel.connect(callback=self.stop)
self.wait()
return channel
def setUp(self):
super(IPCMessageClient, self).setUp()
self.channel = self._get_channel()
def tearDown(self):
super(IPCMessageClient, self).setUp()
self.channel.close()
def test_basic_send(self):
msg = {'foo': 'bar', 'stop': True}
self.channel.send(msg)
self.wait()
self.assertEqual(self.payloads[0], msg)
def test_many_send(self):
msgs = []
self.server_channel.stream_handler = MagicMock()
for i in range(0, 1000):
msgs.append('test_many_send_{0}'.format(i))
for i in msgs:
self.channel.send(i)
self.channel.send({'stop': True})
self.wait()
self.assertEqual(self.payloads[:-1], msgs)
def test_very_big_message(self):
long_str = ''.join([str(num) for num in range(10**5)])
msg = {'long_str': long_str, 'stop': True}
self.channel.send(msg)
self.wait()
self.assertEqual(msg, self.payloads[0])
def test_multistream_sends(self):
local_channel = self._get_channel()
for c in (self.channel, local_channel):
c.send('foo')
self.channel.send({'stop': True})
self.wait()
self.assertEqual(self.payloads[:-1], ['foo', 'foo'])
def test_multistream_errors(self):
local_channel = self._get_channel()
for c in (self.channel, local_channel):
c.send(None)
for c in (self.channel, local_channel):
c.send('foo')
self.channel.send({'stop': True})
self.wait()
self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo'])
if __name__ == '__main__':
from integration import run_tests
run_tests(IPCMessageClient, needs_daemon=False)