diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index a709d5afe8..ee2309dbcd 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -13,6 +13,7 @@ import signal import hashlib import logging import weakref +import threading from random import randint # Import Salt Libs @@ -738,6 +739,9 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel): ''' Encapsulate synchronous operations for a publisher channel ''' + + _sock_data = threading.local() + def __init__(self, opts): self.opts = opts self.serial = salt.payload.Serial(self.opts) # TODO: in init? @@ -773,9 +777,11 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel): # IPv6 sockets work for both IPv6 and IPv4 addresses pub_sock.setsockopt(zmq.IPV4ONLY, 0) pub_sock.setsockopt(zmq.BACKLOG, self.opts.get('zmq_backlog', 1000)) + pub_sock.setsockopt(zmq.LINGER, -1) pub_uri = 'tcp://{interface}:{publish_port}'.format(**self.opts) # Prepare minion pull socket pull_sock = context.socket(zmq.PULL) + pull_sock.setsockopt(zmq.LINGER, -1) if self.opts.get('ipc_mode', '') == 'tcp': pull_uri = 'tcp://127.0.0.1:{0}'.format( @@ -838,15 +844,14 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel): raise exc except KeyboardInterrupt: - # Cleanly close the sockets if we're shutting down - if pub_sock.closed is False: - pub_sock.setsockopt(zmq.LINGER, 1) - pub_sock.close() - if pull_sock.closed is False: - pull_sock.setsockopt(zmq.LINGER, 1) - pull_sock.close() - if context.closed is False: - context.term() + log.trace('Publish daemon caught Keyboard interupt, tearing down') + # Cleanly close the sockets if we're shutting down + if pub_sock.closed is False: + pub_sock.close() + if pull_sock.closed is False: + pull_sock.close() + if context.closed is False: + context.term() def pre_fork(self, process_manager, kwargs=None): ''' @@ -858,23 +863,29 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel): ''' 
process_manager.add_process(self._publish_daemon, kwargs=kwargs) - def publish(self, load): + @property + def pub_sock(self): ''' - Publish "load" to minions - - :param dict load: A load to be sent across the wire to minions + This thread's zmq publisher socket. This socket is stored on the class + so that multiple instantiations in the same thread will re-use a single + zmq socket. Returns None when this thread has no socket. ''' - payload = {'enc': 'aes'} + try: + return self._sock_data.sock + except AttributeError: + pass - crypticle = salt.crypt.Crypticle(self.opts, salt.master.SMaster.secrets['aes']['secret'].value) - payload['load'] = crypticle.dumps(load) - if self.opts['sign_pub_messages']: - master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem') - log.debug("Signing data packet") - payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load']) - # Send 0MQ to the publisher - context = zmq.Context(1) - pub_sock = context.socket(zmq.PUSH) + def pub_connect(self): + ''' + Create and connect this thread's zmq socket. If a publisher socket + already exists "pub_close" is called before creating and connecting a + new socket. + ''' + if self.pub_sock: + self.pub_close() + ctx = zmq.Context.instance() + self._sock_data.sock = ctx.socket(zmq.PUSH) + self.pub_sock.setsockopt(zmq.LINGER, -1) if self.opts.get('ipc_mode', '') == 'tcp': pull_uri = 'tcp://127.0.0.1:{0}'.format( self.opts.get('tcp_master_publish_pull', 4514) @@ -883,7 +894,32 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel): pull_uri = 'ipc://{0}'.format( os.path.join(self.opts['sock_dir'], 'publish_pull.ipc') ) - pub_sock.connect(pull_uri) + self.pub_sock.connect(pull_uri) + return self._sock_data.sock + + def pub_close(self): + ''' + Disconnect an existing publisher socket and remove it from the local + thread's cache. + ''' + if hasattr(self._sock_data, 'sock'): + self._sock_data.sock.close() + delattr(self._sock_data, 'sock') + + def publish(self, load): + ''' + Publish "load" to minions. 
This sends the load to the publisher daemon + process which does the actual sending to minions. + + :param dict load: A load to be sent across the wire to minions + ''' + payload = {'enc': 'aes'} + crypticle = salt.crypt.Crypticle(self.opts, salt.master.SMaster.secrets['aes']['secret'].value) + payload['load'] = crypticle.dumps(load) + if self.opts['sign_pub_messages']: + master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem') + log.debug("Signing data packet") + payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load']) int_payload = {'payload': self.serial.dumps(payload)} # add some targeting stuff for lists only (for now) @@ -905,12 +941,11 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel): 'Sending payload to publish daemon. jid=%s size=%d', load.get('jid', None), len(payload), ) - pub_sock.send(payload) + if not self.pub_sock: + self.pub_connect() + self.pub_sock.send(payload) log.debug('Sent payload to publish daemon.') - pub_sock.close() - context.term() - class AsyncReqMessageClientPool(salt.transport.MessageClientPool): ''' diff --git a/tests/unit/transport/test_zeromq.py b/tests/unit/transport/test_zeromq.py index c604a0de19..814ef39844 100644 --- a/tests/unit/transport/test_zeromq.py +++ b/tests/unit/transport/test_zeromq.py @@ -8,6 +8,9 @@ from __future__ import absolute_import import os import time import threading +import multiprocessing +import ctypes +from concurrent.futures.thread import ThreadPoolExecutor # linux_distribution deprecated in py3.7 try: @@ -25,6 +28,7 @@ import tornado.gen # Import Salt libs import salt.config +import salt.log.setup import salt.ext.six as six import salt.utils import salt.transport.server @@ -305,3 +309,223 @@ class AsyncReqMessageClientPoolTest(TestCase): def test_destroy(self): self.message_client_pool.destroy() self.assertEqual([], self.message_client_pool.message_clients) + + +class PubServerChannel(TestCase, AdaptedConfigurationTestCaseMixin): + + @classmethod + def 
setUpClass(cls): + ret_port = get_unused_localhost_port() + publish_port = get_unused_localhost_port() + tcp_master_pub_port = get_unused_localhost_port() + tcp_master_pull_port = get_unused_localhost_port() + tcp_master_publish_pull = get_unused_localhost_port() + tcp_master_workers = get_unused_localhost_port() + cls.master_config = cls.get_temp_config( + 'master', + **{'transport': 'zeromq', + 'auto_accept': True, + 'ret_port': ret_port, + 'publish_port': publish_port, + 'tcp_master_pub_port': tcp_master_pub_port, + 'tcp_master_pull_port': tcp_master_pull_port, + 'tcp_master_publish_pull': tcp_master_publish_pull, + 'tcp_master_workers': tcp_master_workers, + 'sign_pub_messages': False, + } + ) + salt.master.SMaster.secrets['aes'] = { + 'secret': multiprocessing.Array( + ctypes.c_char, + six.b(salt.crypt.Crypticle.generate_key_string()), + ), + } + cls.minion_config = cls.get_temp_config( + 'minion', + **{'transport': 'zeromq', + 'master_ip': '127.0.0.1', + 'master_port': ret_port, + 'auth_timeout': 5, + 'auth_tries': 1, + 'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)} + ) + + @classmethod + def tearDownClass(cls): + del cls.minion_config + del cls.master_config + + def setUp(self): + # Start the event loop, even though we don't directly use this with + # ZeroMQPubServerChannel, having it running seems to increase the + # likelihood of dropped messages. 
+ self.io_loop = zmq.eventloop.ioloop.ZMQIOLoop() + self.io_loop.make_current() + self.io_loop_thread = threading.Thread(target=self.io_loop.start) + self.io_loop_thread.start() + self.process_manager = salt.utils.process.ProcessManager(name='PubServer_ProcessManager') + + def tearDown(self): + self.io_loop.add_callback(self.io_loop.stop) + self.io_loop_thread.join() + self.process_manager.stop_restarting() + self.process_manager.kill_children() + del self.io_loop + del self.io_loop_thread + del self.process_manager + + @staticmethod + def _gather_results(opts, pub_uri, results, timeout=120): + ''' + Gather results until the number of seconds specified by timeout passes + without receiving a message, or until a payload containing 'stop' arrives + ''' + ctx = zmq.Context() + sock = ctx.socket(zmq.SUB) + sock.setsockopt(zmq.LINGER, -1) + sock.subscribe(b'') + sock.connect(pub_uri) + last_msg = time.time() + serial = salt.payload.Serial(opts) + crypticle = salt.crypt.Crypticle(opts, salt.master.SMaster.secrets['aes']['secret'].value) + while time.time() - last_msg < timeout: + try: + payload = sock.recv(zmq.NOBLOCK) + except zmq.ZMQError: + time.sleep(.01) + else: + payload = crypticle.loads(serial.loads(payload)['load']) + if 'stop' in payload: + break + last_msg = time.time() + results.append(payload['jid']) + return results + + @skipIf(salt.utils.is_windows(), 'Skip on Windows OS') + def test_publish_to_pubserv_ipc(self): + ''' + Test sending 10K messages to ZeroMQPubServerChannel using IPC transport + + ZMQ's ipc transport not supported on Windows + ''' + opts = dict(self.master_config, ipc_mode='ipc', pub_hwm=0) + server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts) + server_channel.pre_fork(self.process_manager, kwargs={ + 'log_queue': salt.log.setup.get_multiprocessing_logging_queue() + }) + pub_uri = 'tcp://{interface}:{publish_port}'.format(**server_channel.opts) + send_num = 10000 + expect = [] + results = [] + gather = threading.Thread(target=self._gather_results, 
args=(self.minion_config, pub_uri, results,)) + gather.start() + # Allow time for server channel to start, especially on windows + time.sleep(2) + for i in range(send_num): + expect.append(i) + load = {'tgt_type': 'glob', 'tgt': '*', 'jid': i} + server_channel.publish(load) + server_channel.publish( + {'tgt_type': 'glob', 'tgt': '*', 'stop': True} + ) + gather.join() + server_channel.pub_close() + assert len(results) == send_num, (len(results), set(expect).difference(results)) + + def test_publish_to_pubserv_tcp(self): + ''' + Test sending 10K messages to ZeroMQPubServerChannel using TCP transport + ''' + opts = dict(self.master_config, ipc_mode='tcp', pub_hwm=0) + server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts) + server_channel.pre_fork(self.process_manager, kwargs={ + 'log_queue': salt.log.setup.get_multiprocessing_logging_queue() + }) + pub_uri = 'tcp://{interface}:{publish_port}'.format(**server_channel.opts) + send_num = 10000 + expect = [] + results = [] + gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,)) + gather.start() + # Allow time for server channel to start, especially on windows + time.sleep(2) + for i in range(send_num): + expect.append(i) + load = {'tgt_type': 'glob', 'tgt': '*', 'jid': i} + server_channel.publish(load) + # Tell the gather thread to stop rather than idle until its timeout + server_channel.publish({'tgt_type': 'glob', 'tgt': '*', 'stop': True}) + gather.join() + server_channel.pub_close() + assert len(results) == send_num, (len(results), set(expect).difference(results)) + + @staticmethod + def _send_small(opts, sid, num=10): + server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts) + for i in range(num): + load = {'tgt_type': 'glob', 'tgt': '*', 'jid': '{}-{}'.format(sid, i)} + server_channel.publish(load) + + @staticmethod + def _send_large(opts, sid, num=10, size=250000 * 3): + server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts) + for i in range(num): + load = {'tgt_type': 'glob', 'tgt': '*', 'jid': '{}-{}'.format(sid, i), 'xdata': '0' * size} + server_channel.publish(load) + 
+ def test_issue_36469_tcp(self): + ''' + Test sending both large and small messages to publisher using TCP + ''' + opts = dict(self.master_config, ipc_mode='tcp', pub_hwm=0) + server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts) + server_channel.pre_fork(self.process_manager, kwargs={ + 'log_queue': salt.log.setup.get_multiprocessing_logging_queue() + }) + send_num = 10 * 4 + expect = [] + results = [] + pub_uri = 'tcp://{interface}:{publish_port}'.format(**opts) + # Allow time for server channel to start, especially on windows + time.sleep(2) + gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,)) + gather.start() + with ThreadPoolExecutor(max_workers=4) as executor: + executor.submit(self._send_small, opts, 1) + executor.submit(self._send_small, opts, 2) + executor.submit(self._send_small, opts, 3) + executor.submit(self._send_large, opts, 4) + expect = ['{}-{}'.format(sid, i) for sid in (1, 2, 3, 4) for i in range(10)] + server_channel.publish({'tgt_type': 'glob', 'tgt': '*', 'stop': True}) + gather.join() + server_channel.pub_close() + assert len(results) == send_num, (len(results), set(expect).difference(results)) + + @skipIf(salt.utils.is_windows(), 'Skip on Windows OS') + def test_issue_36469_udp(self): + ''' + Test sending both large and small messages to publisher with ipc_mode='udp' + ''' + opts = dict(self.master_config, ipc_mode='udp', pub_hwm=0) + server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts) + server_channel.pre_fork(self.process_manager, kwargs={ + 'log_queue': salt.log.setup.get_multiprocessing_logging_queue() + }) + send_num = 10 * 4 + expect = [] + results = [] + pub_uri = 'tcp://{interface}:{publish_port}'.format(**opts) + # Allow time for server channel to start, especially on windows + time.sleep(2) + gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,)) + gather.start() + with ThreadPoolExecutor(max_workers=4) as executor: + 
executor.submit(self._send_small, opts, 1) + executor.submit(self._send_small, opts, 2) + executor.submit(self._send_small, opts, 3) + executor.submit(self._send_large, opts, 4) + expect = ['{}-{}'.format(sid, i) for sid in (1, 2, 3, 4) for i in range(10)] + server_channel.publish({'tgt_type': 'glob', 'tgt': '*', 'stop': True}) + gather.join() + server_channel.pub_close() + assert len(results) == send_num, (len(results), set(expect).difference(results))