More reliable pub server publishing

- Uses existing zmq context
- Store socket in thread local storage for re-use
- Add tests to verify publish reliability
This commit is contained in:
Daniel A. Wozniak 2018-11-11 00:30:23 -07:00
parent c85561ee56
commit 6882209e3b
No known key found for this signature in database
GPG Key ID: 166B9D2C06C82D61
2 changed files with 285 additions and 28 deletions

View File

@ -13,6 +13,7 @@ import signal
import hashlib
import logging
import weakref
import threading
from random import randint
# Import Salt Libs
@ -738,6 +739,9 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
'''
Encapsulate synchronous operations for a publisher channel
'''
_sock_data = threading.local()
def __init__(self, opts):
self.opts = opts
self.serial = salt.payload.Serial(self.opts) # TODO: in init?
@ -773,9 +777,11 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
# IPv6 sockets work for both IPv6 and IPv4 addresses
pub_sock.setsockopt(zmq.IPV4ONLY, 0)
pub_sock.setsockopt(zmq.BACKLOG, self.opts.get('zmq_backlog', 1000))
pub_sock.setsockopt(zmq.LINGER, -1)
pub_uri = 'tcp://{interface}:{publish_port}'.format(**self.opts)
# Prepare minion pull socket
pull_sock = context.socket(zmq.PULL)
pull_sock.setsockopt(zmq.LINGER, -1)
if self.opts.get('ipc_mode', '') == 'tcp':
pull_uri = 'tcp://127.0.0.1:{0}'.format(
@ -838,12 +844,11 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
raise exc
except KeyboardInterrupt:
log.trace('Publish daemon caught Keyboard interupt, tearing down')
# Cleanly close the sockets if we're shutting down
if pub_sock.closed is False:
pub_sock.setsockopt(zmq.LINGER, 1)
pub_sock.close()
if pull_sock.closed is False:
pull_sock.setsockopt(zmq.LINGER, 1)
pull_sock.close()
if context.closed is False:
context.term()
@ -858,23 +863,29 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
'''
process_manager.add_process(self._publish_daemon, kwargs=kwargs)
def publish(self, load):
@property
def pub_sock(self):
'''
Publish "load" to minions
:param dict load: A load to be sent across the wire to minions
This thread's zmq publisher socket. This socket is stored on the class
so that multiple instantiations in the same thread will re-use a single
zmq socket.
'''
payload = {'enc': 'aes'}
try:
return self._sock_data.sock
except AttributeError:
pass
crypticle = salt.crypt.Crypticle(self.opts, salt.master.SMaster.secrets['aes']['secret'].value)
payload['load'] = crypticle.dumps(load)
if self.opts['sign_pub_messages']:
master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem')
log.debug("Signing data packet")
payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load'])
# Send 0MQ to the publisher
context = zmq.Context(1)
pub_sock = context.socket(zmq.PUSH)
def pub_connect(self):
'''
Create and connect this thread's zmq socket. If a publisher socket
already exists "pub_close" is called before creating and connecting a
new socket.
'''
if self.pub_sock:
self.pub_close()
ctx = zmq.Context.instance()
self._sock_data.sock = ctx.socket(zmq.PUSH)
self.pub_sock.setsockopt(zmq.LINGER, -1)
if self.opts.get('ipc_mode', '') == 'tcp':
pull_uri = 'tcp://127.0.0.1:{0}'.format(
self.opts.get('tcp_master_publish_pull', 4514)
@ -883,7 +894,32 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
pull_uri = 'ipc://{0}'.format(
os.path.join(self.opts['sock_dir'], 'publish_pull.ipc')
)
pub_sock.connect(pull_uri)
self.pub_sock.connect(pull_uri)
return self._sock_data.sock
def pub_close(self):
'''
Disconnect an existing publisher socket and remove it from the local
thread's cache.
'''
if hasattr(self._sock_data, 'sock'):
self._sock_data.sock.close()
delattr(self._sock_data, 'sock')
def publish(self, load):
'''
Publish "load" to minions. This send the load to the publisher daemon
process with does the actual sending to minions.
:param dict load: A load to be sent across the wire to minions
'''
payload = {'enc': 'aes'}
crypticle = salt.crypt.Crypticle(self.opts, salt.master.SMaster.secrets['aes']['secret'].value)
payload['load'] = crypticle.dumps(load)
if self.opts['sign_pub_messages']:
master_pem_path = os.path.join(self.opts['pki_dir'], 'master.pem')
log.debug("Signing data packet")
payload['sig'] = salt.crypt.sign_message(master_pem_path, payload['load'])
int_payload = {'payload': self.serial.dumps(payload)}
# add some targeting stuff for lists only (for now)
@ -905,12 +941,11 @@ class ZeroMQPubServerChannel(salt.transport.server.PubServerChannel):
'Sending payload to publish daemon. jid=%s size=%d',
load.get('jid', None), len(payload),
)
pub_sock.send(payload)
if not self.pub_sock:
self.pub_connect()
self.pub_sock.send(payload)
log.debug('Sent payload to publish daemon.')
pub_sock.close()
context.term()
class AsyncReqMessageClientPool(salt.transport.MessageClientPool):
'''

View File

@ -8,6 +8,9 @@ from __future__ import absolute_import
import os
import time
import threading
import multiprocessing
import ctypes
from concurrent.futures.thread import ThreadPoolExecutor
# linux_distribution deprecated in py3.7
try:
@ -25,6 +28,7 @@ import tornado.gen
# Import Salt libs
import salt.config
import salt.log.setup
import salt.ext.six as six
import salt.utils
import salt.transport.server
@ -305,3 +309,221 @@ class AsyncReqMessageClientPoolTest(TestCase):
def test_destroy(self):
self.message_client_pool.destroy()
self.assertEqual([], self.message_client_pool.message_clients)
class PubServerChannel(TestCase, AdaptedConfigurationTestCaseMixin):
@classmethod
def setUpClass(cls):
ret_port = get_unused_localhost_port()
publish_port = get_unused_localhost_port()
tcp_master_pub_port = get_unused_localhost_port()
tcp_master_pull_port = get_unused_localhost_port()
tcp_master_publish_pull = get_unused_localhost_port()
tcp_master_workers = get_unused_localhost_port()
cls.master_config = cls.get_temp_config(
'master',
**{'transport': 'zeromq',
'auto_accept': True,
'ret_port': ret_port,
'publish_port': publish_port,
'tcp_master_pub_port': tcp_master_pub_port,
'tcp_master_pull_port': tcp_master_pull_port,
'tcp_master_publish_pull': tcp_master_publish_pull,
'tcp_master_workers': tcp_master_workers,
'sign_pub_messages': False,
}
)
salt.master.SMaster.secrets['aes'] = {
'secret': multiprocessing.Array(
ctypes.c_char,
six.b(salt.crypt.Crypticle.generate_key_string()),
),
}
cls.minion_config = cls.get_temp_config(
'minion',
**{'transport': 'zeromq',
'master_ip': '127.0.0.1',
'master_port': ret_port,
'auth_timeout': 5,
'auth_tries': 1,
'master_uri': 'tcp://127.0.0.1:{0}'.format(ret_port)}
)
@classmethod
def tearDownClass(cls):
del cls.minion_config
del cls.master_config
def setUp(self):
# Start the event loop, even though we dont directly use this with
# ZeroMQPubServerChannel, having it running seems to increase the
# likely hood of dropped messages.
self.io_loop = zmq.eventloop.ioloop.ZMQIOLoop()
self.io_loop.make_current()
self.io_loop_thread = threading.Thread(target=self.io_loop.start)
self.io_loop_thread.start()
self.process_manager = salt.utils.process.ProcessManager(name='PubServer_ProcessManager')
def tearDown(self):
self.io_loop.add_callback(self.io_loop.stop)
self.io_loop_thread.join()
self.process_manager.stop_restarting()
self.process_manager.kill_children()
del self.io_loop
del self.io_loop_thread
del self.process_manager
@staticmethod
def _gather_results(opts, pub_uri, results, timeout=120):
'''
Gather results until then number of seconds specified by timeout passes
without reveiving a message
'''
ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt(zmq.LINGER, -1)
sock.subscribe(b'')
sock.connect(pub_uri)
last_msg = time.time()
serial = salt.payload.Serial(opts)
crypticle = salt.crypt.Crypticle(opts, salt.master.SMaster.secrets['aes']['secret'].value)
while time.time() - last_msg < timeout:
try:
payload = sock.recv(zmq.NOBLOCK)
except zmq.ZMQError:
time.sleep(.01)
else:
payload = crypticle.loads(serial.loads(payload)['load'])
if 'stop' in payload:
break
last_msg = time.time()
results.append(payload['jid'])
return results
@skipIf(salt.utils.is_windows(), 'Skip on Windows OS')
def test_publish_to_pubserv_ipc(self):
'''
Test sending 10K messags to ZeroMQPubServerChannel using IPC transport
ZMQ's ipc transport not supported on Windows
'''
opts = dict(self.master_config, ipc_mode='ipc', pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(self.process_manager, kwargs={
'log_queue': salt.log.setup.get_multiprocessing_logging_queue()
})
pub_uri = 'tcp://{interface}:{publish_port}'.format(**server_channel.opts)
send_num = 10000
expect = []
results = []
gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,))
gather.start()
# Allow time for server channel to start, especially on windows
time.sleep(2)
for i in range(send_num):
expect.append(i)
load = {'tgt_type': 'glob', 'tgt': '*', 'jid': i}
server_channel.publish(load)
server_channel.publish(
{'tgt_type': 'glob', 'tgt': '*', 'stop': True}
)
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))
def test_publish_to_pubserv_tcp(self):
'''
Test sending 10K messags to ZeroMQPubServerChannel using TCP transport
'''
opts = dict(self.master_config, ipc_mode='tcp', pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(self.process_manager, kwargs={
'log_queue': salt.log.setup.get_multiprocessing_logging_queue()
})
pub_uri = 'tcp://{interface}:{publish_port}'.format(**server_channel.opts)
send_num = 10000
expect = []
results = []
gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,))
gather.start()
# Allow time for server channel to start, especially on windows
time.sleep(2)
for i in range(send_num):
expect.append(i)
load = {'tgt_type': 'glob', 'tgt': '*', 'jid': i}
server_channel.publish(load)
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))
@staticmethod
def _send_small(opts, sid, num=10):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
for i in range(num):
load = {'tgt_type': 'glob', 'tgt': '*', 'jid': '{}-{}'.format(sid, i)}
server_channel.publish(load)
@staticmethod
def _send_large(opts, sid, num=10, size=250000 * 3):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
for i in range(num):
load = {'tgt_type': 'glob', 'tgt': '*', 'jid': '{}-{}'.format(sid, i), 'xdata': '0' * size}
server_channel.publish(load)
def test_issue_36469_tcp(self):
'''
Test sending both large and small messags to publisher using TCP
'''
opts = dict(self.master_config, ipc_mode='tcp', pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(self.process_manager, kwargs={
'log_queue': salt.log.setup.get_multiprocessing_logging_queue()
})
send_num = 10 * 4
expect = []
results = []
pub_uri = 'tcp://{interface}:{publish_port}'.format(**opts)
# Allow time for server channel to start, especially on windows
time.sleep(2)
gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,))
gather.start()
with ThreadPoolExecutor(max_workers=4) as executor:
executor.submit(self._send_small, opts, 1)
executor.submit(self._send_small, opts, 2)
executor.submit(self._send_small, opts, 3)
executor.submit(self._send_large, opts, 4)
expect = ['{}-{}'.format(a, b) for a in range(10) for b in (1, 2, 3, 4)]
server_channel.publish({'tgt_type': 'glob', 'tgt': '*', 'stop': True})
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))
@skipIf(salt.utils.is_windows(), 'Skip on Windows OS')
def test_issue_36469_udp(self):
'''
Test sending both large and small messags to publisher using UDP
'''
opts = dict(self.master_config, ipc_mode='udp', pub_hwm=0)
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
server_channel.pre_fork(self.process_manager, kwargs={
'log_queue': salt.log.setup.get_multiprocessing_logging_queue()
})
send_num = 10 * 4
expect = []
results = []
pub_uri = 'tcp://{interface}:{publish_port}'.format(**opts)
# Allow time for server channel to start, especially on windows
time.sleep(2)
gather = threading.Thread(target=self._gather_results, args=(self.minion_config, pub_uri, results,))
gather.start()
with ThreadPoolExecutor(max_workers=4) as executor:
executor.submit(self._send_small, opts, 1)
executor.submit(self._send_small, opts, 2)
executor.submit(self._send_small, opts, 3)
executor.submit(self._send_large, opts, 4)
expect = ['{}-{}'.format(a, b) for a in range(10) for b in (1, 2, 3, 4)]
server_channel.publish({'tgt_type': 'glob', 'tgt': '*', 'stop': True})
gather.join()
server_channel.pub_close()
assert len(results) == send_num, (len(results), set(expect).difference(results))