mirror of
https://github.com/valitydev/salt.git
synced 2024-11-07 00:55:19 +00:00
3cc84ecbf7
Fixes #26850
356 lines
12 KiB
Python
356 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
'''
|
|
Many aspects of the salt payload need to be managed, from the return of
|
|
encrypted keys to general payload dynamics and packaging, these happen
|
|
in here
|
|
'''
|
|
|
|
# Import python libs
|
|
from __future__ import absolute_import
|
|
# import sys # Use if sys is commented out below
|
|
import logging
|
|
import gc
|
|
import datetime
|
|
|
|
# Import salt libs
|
|
import salt.log
|
|
import salt.crypt
|
|
from salt.exceptions import SaltReqTimeoutError
|
|
|
|
# Import third party libs
|
|
import salt.ext.six as six
|
|
try:
|
|
import zmq
|
|
except ImportError:
|
|
# No need for zeromq in local mode
|
|
pass
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
# HAS_MSGPACK records whether a usable msgpack implementation was imported.
# It stays False when neither msgpack nor msgpack_pure can be imported.
HAS_MSGPACK = False
try:
    # Attempt to import msgpack
    import msgpack
    # There is a serialization issue on ARM and potentially other platforms
    # for some msgpack bindings, check for it: a round trip of a small list
    # that comes back as None is treated the same as a failed import.
    if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None:
        raise ImportError
    HAS_MSGPACK = True
except ImportError:
    # Fall back to msgpack_pure, bound to the same ``msgpack`` name so the
    # rest of the module does not care which implementation is in use.
    try:
        import msgpack_pure as msgpack  # pylint: disable=import-error
        HAS_MSGPACK = True
    except ImportError:
        # TODO: Come up with a sane way to get a configured logfile
        #       and write to the logfile when this error is hit also
        LOG_FORMAT = '[%(levelname)-8s] %(message)s'
        # Set up console logging before emitting the error message.
        salt.log.setup_console_logger(log_format=LOG_FORMAT)
        log.fatal('Unable to import msgpack or msgpack_pure python modules')
        # Don't exit if msgpack is not available, this is to make local mode
        # work without msgpack
        #sys.exit(salt.defaults.exitcodes.EX_GENERIC)
|
|
|
|
|
|
if HAS_MSGPACK and not hasattr(msgpack, 'exceptions'):
    class PackValueError(Exception):
        '''
        older versions of msgpack do not have PackValueError
        '''

    class exceptions(object):
        '''
        older versions of msgpack do not have an exceptions module
        '''

    # Expose the exception *class* (not an instance) on the stand-in, because
    # ``msgpack.exceptions.PackValueError`` is used inside an ``except``
    # clause and an ``except`` clause has to be given exception classes.
    exceptions.PackValueError = PackValueError

    msgpack.exceptions = exceptions()
|
|
|
|
|
|
def package(payload):
    '''
    Serialize ``payload`` and return the result.

    Currently a thin wrapper around ``msgpack.dumps``; it exists so the
    serialization backend can be swapped out later without touching callers.
    '''
    serialized = msgpack.dumps(payload)
    return serialized
|
|
|
|
|
|
def unpackage(package_):
    '''
    Deserialize a packaged payload and return the resulting object.

    The data is handed to ``msgpack.loads`` with ``use_list=True``.
    '''
    unpacked = msgpack.loads(package_, use_list=True)
    return unpacked
|
|
|
|
|
|
def format_payload(enc, **kwargs):
    '''
    Build and serialize a payload.

    The payload is a dict whose ``enc`` entry is the given encryption type
    and whose ``load`` entry is a dict made from the keyword arguments. The
    serialized form produced by ``package`` is returned.
    '''
    body = dict(kwargs)
    return package({'enc': enc, 'load': body})
|
|
|
|
|
|
class Serial(object):
    '''
    Create a serialization object, this object manages all message
    serialization in Salt
    '''
    # Integers outside this inclusive range are converted to strings before
    # retrying a failed dump (int64 minimum / uint64 maximum).
    _MSGPACK_INT_MIN = -(2 ** 63)
    _MSGPACK_INT_MAX = 2 ** 64 - 1

    def __init__(self, opts):
        '''
        Record the name of the configured serializer in ``self.serial``.

        ``opts`` may be a dict (its ``serial`` key is used, defaulting to
        ``'msgpack'``), a string naming the serializer, or anything else, in
        which case ``'msgpack'`` is used.
        '''
        if isinstance(opts, dict):
            self.serial = opts.get('serial', 'msgpack')
        elif isinstance(opts, str):
            self.serial = opts
        else:
            self.serial = 'msgpack'

    @staticmethod
    def _map_nested(obj, leaf):
        '''
        Return a rebuilt copy of ``obj`` in which every dict becomes a plain
        ``dict``, every list/tuple becomes a ``list`` and ``leaf`` is applied
        to every other value. The input is never modified.
        '''
        if isinstance(obj, dict):
            return dict(
                (key, Serial._map_nested(value, leaf))
                for key, value in obj.items()
            )
        if isinstance(obj, (list, tuple)):
            return [Serial._map_nested(entry, leaf) for entry in obj]
        return leaf(obj)

    @staticmethod
    def _verylong_encoder(obj):
        '''
        Return a copy of ``obj`` with every integer that lies outside
        ``_MSGPACK_INT_MIN``..``_MSGPACK_INT_MAX`` replaced by its string form.
        '''
        # Local import: ``numbers`` is not imported at the top of this module.
        import numbers

        def stringify(value):
            if (isinstance(value, numbers.Integral) and
                    not Serial._MSGPACK_INT_MIN <= value <= Serial._MSGPACK_INT_MAX):
                return str(value)
            return value

        return Serial._map_nested(obj, stringify)

    @staticmethod
    def _datetime_encoder(obj):
        '''
        Return a copy of ``obj`` with every ``datetime.datetime`` replaced by
        the msgpack-packed form of its ``%Y%m%dT%H:%M:%S.%f`` string.
        '''
        def default(value):
            # msgpack Extended type numbered 78
            return msgpack.ExtType(78, value)

        def dt_encode(value):
            if not isinstance(value, datetime.datetime):
                return value
            datetime_str = value.strftime("%Y%m%dT%H:%M:%S.%f")
            # NOTE(review): the value handed to packb is already a string, so
            # the ``default`` hook looks unused here -- confirm before relying
            # on ExtType 78 appearing in the output.
            return msgpack.packb(datetime_str, default=default)

        return Serial._map_nested(obj, dt_encode)

    @staticmethod
    def _odict_encoder(obj):
        '''
        Return a copy of ``obj`` in which every dict (including dict
        subclasses such as OrderedDict) is a plain ``dict`` and every
        list/tuple is a ``list``.
        '''
        return Serial._map_nested(obj, lambda value: value)

    def loads(self, msg):
        '''
        Run the correct loads serialization format

        Returns the object produced by ``msgpack.loads``. Any exception raised
        while deserializing is logged at critical level and re-raised.
        '''
        # Remember the collector state so it is only switched back on if it
        # was on when we were called.
        gc_was_enabled = gc.isenabled()
        gc.disable()  # performance optimization for msgpack
        try:
            return msgpack.loads(msg, use_list=True)
        except Exception as exc:
            log.critical('Could not deserialize msgpack message: {0} '
                         'This often happens when trying to read a file not in binary mode. '
                         'Please open an issue and include the following error: {1}'.format(msg, exc))
            raise
        finally:
            if gc_was_enabled:
                gc.enable()

    def load(self, fn_):
        '''
        Run the correct serialization to load a file

        Reads everything from the file object ``fn_``, closes it (also when
        reading fails) and returns the deserialized data, or ``None`` when the
        file was empty.
        '''
        try:
            data = fn_.read()
        finally:
            fn_.close()
        if data:
            return self.loads(data)
        return None

    def dumps(self, msg):
        '''
        Run the correct dumps serialization format

        Tries a plain ``msgpack.dumps`` first and, on failure, retries with a
        converted copy of ``msg``:

        - ``OverflowError`` / ``PackValueError``: out-of-range integers are
          converted to strings.
        - ``TypeError`` mentioning ``datetime.datetime``: datetimes are
          converted to packed strings.
        - other ``TypeError`` with msgpack older than 0.2.0: dict subclasses
          are converted to plain dicts; with newer msgpack the error is
          re-raised.

        A ``SystemError`` is logged at critical level and ``None`` is
        returned. ``msg`` itself is never modified.
        '''
        try:
            return msgpack.dumps(msg)
        except (OverflowError, msgpack.exceptions.PackValueError):
            # msgpack can't handle the very long Python longs for jids
            # Convert any very long longs to strings
            return msgpack.dumps(self._verylong_encoder(msg))
        except TypeError as exc:
            # msgpack doesn't support datetime.datetime datatype
            if 'datetime.datetime' in str(exc):
                return msgpack.dumps(self._datetime_encoder(msg))

            if msgpack.version >= (0, 2, 0):
                # Should support OrderedDict serialization, so, let's
                # raise the exception
                raise

            # msgpack is < 0.2.0, let's make its life easier
            # Since OrderedDict is identified as a dictionary, we can't
            # make use of msgpack custom types, we will need to convert by
            # hand.
            return msgpack.dumps(self._odict_encoder(msg))
        except SystemError as exc:
            # TypeError is fully handled by the clause above, so only
            # SystemError can reach this point.
            log.critical(
                'Unable to serialize message! Consider upgrading msgpack. '
                'Message which failed was {failed_message} '
                'with exception {exception_message}'.format(
                    failed_message=msg, exception_message=exc))
            return None

    def dump(self, msg, fn_):
        '''
        Serialize the correct data into the named file object

        ``fn_`` is closed afterwards, also when serializing or writing fails.
        '''
        try:
            fn_.write(self.dumps(msg))
        finally:
            fn_.close()
|
|
|
|
|
|
class SREQ(object):
    '''
    Create a generic interface to wrap salt zeromq req calls.
    '''
    def __init__(self, master, id_='', serial='msgpack', linger=0, opts=None):
        '''
        Store the connection settings and create the zmq context and poller.

        master
            Address handed to ``socket.connect``.
        id_
            Optional identity set on the socket via ``zmq.IDENTITY``.
        serial
            Passed to ``Serial`` to build the (de)serializer.
        linger
            Value assigned to the socket's ``linger`` attribute.
        opts
            Optional mapping; only the ``tcp_keepalive*`` keys are read here.
        '''
        self.master = master
        self.id_ = id_
        self.serial = Serial(serial)
        self.linger = linger
        self.context = zmq.Context()
        self.poller = zmq.Poller()
        self.opts = opts

    @property
    def socket(self):
        '''
        Lazily create the socket.

        On first access a REQ socket is created, configured and connected to
        ``self.master``; later accesses return the same object until
        ``clear_socket`` removes it.
        '''
        if not hasattr(self, '_socket'):
            # create a new one
            self._socket = self.context.socket(zmq.REQ)
            if hasattr(zmq, 'RECONNECT_IVL_MAX'):
                self._socket.setsockopt(
                    zmq.RECONNECT_IVL_MAX, 5000
                )

            self._set_tcp_keepalive()
            if self.master.startswith('tcp://['):
                # Hint PF type if bracket enclosed IPv6 address
                if hasattr(zmq, 'IPV6'):
                    self._socket.setsockopt(zmq.IPV6, 1)
                elif hasattr(zmq, 'IPV4ONLY'):
                    self._socket.setsockopt(zmq.IPV4ONLY, 0)
            self._socket.linger = self.linger
            if self.id_:
                self._socket.setsockopt(zmq.IDENTITY, self.id_)
            self._socket.connect(self.master)
        return self._socket

    def _set_tcp_keepalive(self):
        '''
        Copy any ``tcp_keepalive*`` settings found in ``self.opts`` onto the
        socket. Does nothing when zmq has no ``TCP_KEEPALIVE`` attribute or
        when no opts were given.
        '''
        if hasattr(zmq, 'TCP_KEEPALIVE') and self.opts:
            if 'tcp_keepalive' in self.opts:
                self._socket.setsockopt(
                    zmq.TCP_KEEPALIVE, self.opts['tcp_keepalive']
                )
            if 'tcp_keepalive_idle' in self.opts:
                self._socket.setsockopt(
                    zmq.TCP_KEEPALIVE_IDLE, self.opts['tcp_keepalive_idle']
                )
            if 'tcp_keepalive_cnt' in self.opts:
                self._socket.setsockopt(
                    zmq.TCP_KEEPALIVE_CNT, self.opts['tcp_keepalive_cnt']
                )
            if 'tcp_keepalive_intvl' in self.opts:
                self._socket.setsockopt(
                    zmq.TCP_KEEPALIVE_INTVL, self.opts['tcp_keepalive_intvl']
                )

    def _registered_sockets(self):
        '''
        Return a new list holding the socket objects currently registered
        with the poller.

        ``poller.sockets`` is handled both as a dict keyed by socket and as a
        sequence whose entries carry the socket at index 0. A snapshot is
        returned so callers can unregister sockets while looping over it.
        '''
        registered = self.poller.sockets
        if isinstance(registered, dict):
            return list(registered.keys())
        return [entry[0] for entry in registered]

    def clear_socket(self):
        '''
        delete socket if you have it

        Every socket registered with the poller is unregistered and the cached
        socket is dropped, so the next access to ``socket`` builds a new one.
        '''
        if hasattr(self, '_socket'):
            for socket in self._registered_sockets():
                log.trace('Unregistering socket: {0}'.format(socket))
                self.poller.unregister(socket)
            del self._socket

    def send(self, enc, load, tries=1, timeout=60):
        '''
        Takes two arguments, the encryption type and the base payload

        The payload ``{'enc': enc, 'load': load}`` is serialized and sent,
        then the poller is polled for up to ``timeout`` seconds per try. The
        deserialized reply is returned. After ``tries`` polls without a reply
        the socket is cleared and ``SaltReqTimeoutError`` is raised.
        '''
        payload = {'enc': enc}
        payload['load'] = load
        pkg = self.serial.dumps(payload)
        self.socket.send(pkg)
        self.poller.register(self.socket, zmq.POLLIN)
        tried = 0
        while True:
            # poll() takes milliseconds, timeout is given in seconds
            polled = self.poller.poll(timeout * 1000)
            tried += 1
            if polled:
                break
            if tries > 1:
                log.info('SaltReqTimeoutError: after {0} seconds. (Try {1} of {2})'.format(
                    timeout, tried, tries))
            if tried >= tries:
                self.clear_socket()
                raise SaltReqTimeoutError(
                    'SaltReqTimeoutError: after {0} seconds, ran {1} tries'.format(timeout * tried, tried)
                )
        return self.serial.loads(self.socket.recv())

    def send_auto(self, payload, tries=1, timeout=60):
        '''
        Detect the encryption type based on the payload

        ``enc`` defaults to ``'clear'`` and ``load`` to an empty dict when the
        payload does not carry them; the rest is delegated to ``send``.
        '''
        enc = payload.get('enc', 'clear')
        load = payload.get('load', {})
        return self.send(enc, load, tries, timeout)

    def destroy(self):
        '''
        Close and unregister every socket known to the poller, close the
        cached socket if one exists and terminate the context.

        Attributes are looked up defensively because this also runs from
        ``__del__``, which can be invoked on an object whose ``__init__`` did
        not finish.
        '''
        poller = getattr(self, 'poller', None)
        if poller is not None:
            for socket in self._registered_sockets():
                if socket.closed is False:
                    socket.setsockopt(zmq.LINGER, 1)
                    socket.close()
                poller.unregister(socket)
        # Look at the cached attribute rather than the ``socket`` property:
        # the property would create and connect a new socket just so it could
        # be closed again.
        cached = getattr(self, '_socket', None)
        if cached is not None and cached.closed is False:
            cached.setsockopt(zmq.LINGER, 1)
            cached.close()
        context = getattr(self, 'context', None)
        if context is not None and context.closed is False:
            context.term()

    def __del__(self):
        '''
        Release the zmq resources when the object is garbage collected.
        '''
        self.destroy()
|