mirror of
https://github.com/valitydev/salt.git
synced 2024-11-07 08:58:59 +00:00
Merge pull request #46269 from DSRCorporation/bugs/46202_msgpack_custom_types_rework
msgpack custom types rework
This commit is contained in:
commit
00ca4f01c1
140
salt/payload.py
140
salt/payload.py
@ -17,6 +17,7 @@ import salt.log
|
||||
import salt.crypt
|
||||
import salt.transport.frame
|
||||
import salt.utils.immutabletypes as immutabletypes
|
||||
import salt.utils.stringutils
|
||||
from salt.exceptions import SaltReqTimeoutError
|
||||
|
||||
# Import third party libs
|
||||
@ -128,15 +129,21 @@ class Serial(object):
|
||||
the contents cannot be converted.
|
||||
'''
|
||||
try:
|
||||
def ext_type_decoder(code, data):
|
||||
if code == 78:
|
||||
data = salt.utils.stringutils.to_unicode(data)
|
||||
return datetime.datetime.strptime(data, '%Y%m%dT%H:%M:%S.%f')
|
||||
return data
|
||||
|
||||
gc.disable() # performance optimization for msgpack
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
# msgpack only supports 'encoding' starting in 0.4.0.
|
||||
# Due to this, if we don't need it, don't pass it at all so
|
||||
# that under Python 2 we can still work with older versions
|
||||
# of msgpack.
|
||||
ret = msgpack.loads(msg, use_list=True, encoding=encoding)
|
||||
ret = msgpack.loads(msg, use_list=True, ext_hook=ext_type_decoder, encoding=encoding)
|
||||
else:
|
||||
ret = msgpack.loads(msg, use_list=True)
|
||||
ret = msgpack.loads(msg, use_list=True, ext_hook=ext_type_decoder)
|
||||
if six.PY3 and encoding is None and not raw:
|
||||
ret = salt.transport.frame.decode_embedded_strs(ret)
|
||||
except Exception as exc:
|
||||
@ -175,19 +182,40 @@ class Serial(object):
|
||||
Since this changes the wire protocol, this
|
||||
option should not be used outside of IPC.
|
||||
'''
|
||||
def ext_type_encoder(obj):
|
||||
if isinstance(obj, six.integer_types):
|
||||
# msgpack can't handle the very long Python longs for jids
|
||||
# Convert any very long longs to strings
|
||||
return six.text_type(obj)
|
||||
elif isinstance(obj, datetime.datetime):
|
||||
# msgpack doesn't support datetime.datetime datatype
|
||||
# So here we have converted datetime.datetime to custom datatype
|
||||
# This is msgpack Extended types numbered 78
|
||||
return msgpack.ExtType(78, salt.utils.stringutils.to_bytes(
|
||||
obj.strftime('%Y%m%dT%H:%M:%S.%f')))
|
||||
# The same for immutable types
|
||||
elif isinstance(obj, immutabletypes.ImmutableDict):
|
||||
return dict(obj)
|
||||
elif isinstance(obj, immutabletypes.ImmutableList):
|
||||
return list(obj)
|
||||
elif isinstance(obj, (set, immutabletypes.ImmutableSet)):
|
||||
# msgpack can't handle set so translate it to tuple
|
||||
return tuple(obj)
|
||||
# Nothing known exceptions found. Let msgpack raise it's own.
|
||||
return obj
|
||||
|
||||
try:
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
# msgpack only supports 'use_bin_type' starting in 0.4.0.
|
||||
# Due to this, if we don't need it, don't pass it at all so
|
||||
# that under Python 2 we can still work with older versions
|
||||
# of msgpack.
|
||||
return msgpack.dumps(msg, use_bin_type=use_bin_type)
|
||||
return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type)
|
||||
else:
|
||||
return msgpack.dumps(msg)
|
||||
return msgpack.dumps(msg, default=ext_type_encoder)
|
||||
except (OverflowError, msgpack.exceptions.PackValueError):
|
||||
# msgpack can't handle the very long Python longs for jids
|
||||
# Convert any very long longs to strings
|
||||
# We borrow the technique used by TypeError below
|
||||
# msgpack<=0.4.6 don't call ext encoder on very long integers raising the error instead.
|
||||
# Convert any very long longs to strings and call dumps again.
|
||||
def verylong_encoder(obj):
|
||||
if isinstance(obj, dict):
|
||||
for key, value in six.iteritems(obj.copy()):
|
||||
@ -198,102 +226,18 @@ class Serial(object):
|
||||
for idx, entry in enumerate(obj):
|
||||
obj[idx] = verylong_encoder(entry)
|
||||
return obj
|
||||
# This is a spurious lint failure as we are gating this check
|
||||
# behind a check for six.PY2.
|
||||
if six.PY2 and isinstance(obj, long) and long > pow(2, 64): # pylint: disable=incompatible-py3-code
|
||||
return six.text_type(obj)
|
||||
elif six.PY3 and isinstance(obj, int) and int > pow(2, 64):
|
||||
# A value of an Integer object is limited from -(2^63) upto (2^64)-1 by MessagePack
|
||||
# spec. Here we care only of JIDs that are positive integers.
|
||||
if isinstance(obj, six.integer_types) and obj >= pow(2, 64):
|
||||
return six.text_type(obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
msg = verylong_encoder(msg)
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
return msgpack.dumps(verylong_encoder(msg), use_bin_type=use_bin_type)
|
||||
return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type)
|
||||
else:
|
||||
return msgpack.dumps(verylong_encoder(msg))
|
||||
except TypeError as e:
|
||||
# msgpack doesn't support datetime.datetime datatype
|
||||
# So here we have converted datetime.datetime to custom datatype
|
||||
# This is msgpack Extended types numbered 78
|
||||
def default(obj):
|
||||
return msgpack.ExtType(78, obj)
|
||||
|
||||
def dt_encode(obj):
|
||||
datetime_str = obj.strftime("%Y%m%dT%H:%M:%S.%f")
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
return msgpack.packb(datetime_str, default=default, use_bin_type=use_bin_type)
|
||||
else:
|
||||
return msgpack.packb(datetime_str, default=default)
|
||||
|
||||
def datetime_encoder(obj):
|
||||
if isinstance(obj, dict):
|
||||
for key, value in six.iteritems(obj.copy()):
|
||||
encodedkey = datetime_encoder(key)
|
||||
if key != encodedkey:
|
||||
del obj[key]
|
||||
key = encodedkey
|
||||
obj[key] = datetime_encoder(value)
|
||||
return dict(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
obj = list(obj)
|
||||
for idx, entry in enumerate(obj):
|
||||
obj[idx] = datetime_encoder(entry)
|
||||
return obj
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return dt_encode(obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
def immutable_encoder(obj):
|
||||
log.debug('IMMUTABLE OBJ: %s', obj)
|
||||
if isinstance(obj, immutabletypes.ImmutableDict):
|
||||
return dict(obj)
|
||||
if isinstance(obj, immutabletypes.ImmutableList):
|
||||
return list(obj)
|
||||
if isinstance(obj, immutabletypes.ImmutableSet):
|
||||
return set(obj)
|
||||
|
||||
if "datetime.datetime" in six.text_type(e):
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
return msgpack.dumps(datetime_encoder(msg), use_bin_type=use_bin_type)
|
||||
else:
|
||||
return msgpack.dumps(datetime_encoder(msg))
|
||||
elif "Immutable" in six.text_type(e):
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
return msgpack.dumps(msg, default=immutable_encoder, use_bin_type=use_bin_type)
|
||||
else:
|
||||
return msgpack.dumps(msg, default=immutable_encoder)
|
||||
|
||||
if msgpack.version >= (0, 2, 0):
|
||||
# Should support OrderedDict serialization, so, let's
|
||||
# raise the exception
|
||||
raise
|
||||
|
||||
# msgpack is < 0.2.0, let's make its life easier
|
||||
# Since OrderedDict is identified as a dictionary, we can't
|
||||
# make use of msgpack custom types, we will need to convert by
|
||||
# hand.
|
||||
# This means iterating through all elements of a dictionary or
|
||||
# list/tuple
|
||||
def odict_encoder(obj):
|
||||
if isinstance(obj, dict):
|
||||
for key, value in six.iteritems(obj.copy()):
|
||||
obj[key] = odict_encoder(value)
|
||||
return dict(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
obj = list(obj)
|
||||
for idx, entry in enumerate(obj):
|
||||
obj[idx] = odict_encoder(entry)
|
||||
return obj
|
||||
return obj
|
||||
if msgpack.version >= (0, 4, 0):
|
||||
return msgpack.dumps(odict_encoder(msg), use_bin_type=use_bin_type)
|
||||
else:
|
||||
return msgpack.dumps(odict_encoder(msg))
|
||||
except (SystemError, TypeError) as exc: # pylint: disable=W0705
|
||||
log.critical(
|
||||
'Unable to serialize message! Consider upgrading msgpack. '
|
||||
'Message which failed was %s, with exception %s', msg, exc
|
||||
)
|
||||
return msgpack.dumps(msg, default=ext_type_encoder)
|
||||
|
||||
def dump(self, msg, fn_):
|
||||
'''
|
||||
|
@ -7,24 +7,24 @@
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
'''
|
||||
|
||||
# Import Salt libs
|
||||
# Import python libs
|
||||
from __future__ import absolute_import, print_function, unicode_literals
|
||||
import time
|
||||
import errno
|
||||
import threading
|
||||
import datetime
|
||||
|
||||
# Import Salt Testing libs
|
||||
from tests.support.unit import skipIf, TestCase
|
||||
from tests.support.helpers import MockWraps
|
||||
from tests.support.mock import NO_MOCK, NO_MOCK_REASON, patch
|
||||
from tests.support.mock import NO_MOCK, NO_MOCK_REASON
|
||||
|
||||
# Import salt libs
|
||||
import salt.payload
|
||||
# Import Salt libs
|
||||
from salt.utils import immutabletypes
|
||||
from salt.utils.odict import OrderedDict
|
||||
import salt.exceptions
|
||||
import salt.payload
|
||||
|
||||
# Import 3rd-party libs
|
||||
import msgpack
|
||||
import zmq
|
||||
from salt.ext import six
|
||||
|
||||
@ -49,16 +49,110 @@ class PayloadTestCase(TestCase):
|
||||
self.assertNoOrderedDict(chunk)
|
||||
|
||||
def test_list_nested_odicts(self):
|
||||
with patch('msgpack.version', (0, 1, 13)):
|
||||
msgpack.dumps = MockWraps(
|
||||
msgpack.dumps, 1, TypeError('ODict TypeError Forced')
|
||||
)
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
idata = {'pillar': [OrderedDict(environment='dev')]}
|
||||
odata = payload.loads(payload.dumps(idata.copy()))
|
||||
self.assertNoOrderedDict(odata)
|
||||
self.assertEqual(idata, odata)
|
||||
|
||||
def test_datetime_dump_load(self):
|
||||
'''
|
||||
Check the custom datetime handler can understand itself
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
|
||||
idata = {dtvalue: dtvalue}
|
||||
sdata = payload.dumps(idata.copy())
|
||||
odata = payload.loads(sdata)
|
||||
self.assertEqual(
|
||||
sdata,
|
||||
b'\x81\xc7\x18N20010203T04:05:06.000007\xc7\x18N20010203T04:05:06.000007')
|
||||
self.assertEqual(idata, odata)
|
||||
|
||||
def test_verylong_dump_load(self):
|
||||
'''
|
||||
Test verylong encoder/decoder
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
idata = {'jid': 20180227140750302662}
|
||||
sdata = payload.dumps(idata.copy())
|
||||
odata = payload.loads(sdata)
|
||||
idata['jid'] = '{0}'.format(idata['jid'])
|
||||
self.assertEqual(idata, odata)
|
||||
|
||||
def test_immutable_dict_dump_load(self):
|
||||
'''
|
||||
Test immutable dict encoder/decoder
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
idata = {'dict': {'key': 'value'}}
|
||||
sdata = payload.dumps({'dict': immutabletypes.ImmutableDict(idata['dict'])})
|
||||
odata = payload.loads(sdata)
|
||||
self.assertEqual(idata, odata)
|
||||
|
||||
def test_immutable_list_dump_load(self):
|
||||
'''
|
||||
Test immutable list encoder/decoder
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
idata = {'list': [1, 2, 3]}
|
||||
sdata = payload.dumps({'list': immutabletypes.ImmutableList(idata['list'])})
|
||||
odata = payload.loads(sdata)
|
||||
self.assertEqual(idata, odata)
|
||||
|
||||
def test_immutable_set_dump_load(self):
|
||||
'''
|
||||
Test immutable set encoder/decoder
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
idata = {'set': ['red', 'green', 'blue']}
|
||||
sdata = payload.dumps({'set': immutabletypes.ImmutableSet(idata['set'])})
|
||||
odata = payload.loads(sdata)
|
||||
self.assertEqual(idata, odata)
|
||||
|
||||
def test_odict_dump_load(self):
|
||||
'''
|
||||
Test odict just works. It wasn't until msgpack 0.2.0
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
data = OrderedDict()
|
||||
data['a'] = 'b'
|
||||
data['y'] = 'z'
|
||||
data['j'] = 'k'
|
||||
data['w'] = 'x'
|
||||
sdata = payload.dumps({'set': data})
|
||||
odata = payload.loads(sdata)
|
||||
self.assertEqual({'set': dict(data)}, odata)
|
||||
|
||||
def test_mixed_dump_load(self):
|
||||
'''
|
||||
Test we can handle all exceptions at once
|
||||
'''
|
||||
payload = salt.payload.Serial('msgpack')
|
||||
dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
|
||||
od = OrderedDict()
|
||||
od['a'] = 'b'
|
||||
od['y'] = 'z'
|
||||
od['j'] = 'k'
|
||||
od['w'] = 'x'
|
||||
idata = {dtvalue: dtvalue, # datetime
|
||||
'jid': 20180227140750302662, # long int
|
||||
'dict': immutabletypes.ImmutableDict({'key': 'value'}), # immutable dict
|
||||
'list': immutabletypes.ImmutableList([1, 2, 3]), # immutable list
|
||||
'set': immutabletypes.ImmutableSet(('red', 'green', 'blue')), # immutable set
|
||||
'odict': od, # odict
|
||||
}
|
||||
edata = {dtvalue: dtvalue, # datetime, == input
|
||||
'jid': '20180227140750302662', # string repr of long int
|
||||
'dict': {'key': 'value'}, # builtin dict
|
||||
'list': [1, 2, 3], # builtin list
|
||||
'set': ['red', 'green', 'blue'], # builtin set
|
||||
'odict': dict(od), # builtin dict
|
||||
}
|
||||
sdata = payload.dumps(idata)
|
||||
odata = payload.loads(sdata)
|
||||
self.assertEqual(edata, odata)
|
||||
|
||||
|
||||
class SREQTestCase(TestCase):
|
||||
port = 8845 # TODO: dynamically assign a port?
|
||||
|
Loading…
Reference in New Issue
Block a user