Merge pull request #46269 from DSRCorporation/bugs/46202_msgpack_custom_types_rework

msgpack custom types rework
This commit is contained in:
Nicole Thomas 2018-03-02 11:22:55 -05:00 committed by GitHub
commit 00ca4f01c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 151 additions and 113 deletions

View File

@ -17,6 +17,7 @@ import salt.log
import salt.crypt import salt.crypt
import salt.transport.frame import salt.transport.frame
import salt.utils.immutabletypes as immutabletypes import salt.utils.immutabletypes as immutabletypes
import salt.utils.stringutils
from salt.exceptions import SaltReqTimeoutError from salt.exceptions import SaltReqTimeoutError
# Import third party libs # Import third party libs
@ -128,15 +129,21 @@ class Serial(object):
the contents cannot be converted. the contents cannot be converted.
''' '''
try: try:
def ext_type_decoder(code, data):
if code == 78:
data = salt.utils.stringutils.to_unicode(data)
return datetime.datetime.strptime(data, '%Y%m%dT%H:%M:%S.%f')
return data
gc.disable() # performance optimization for msgpack gc.disable() # performance optimization for msgpack
if msgpack.version >= (0, 4, 0): if msgpack.version >= (0, 4, 0):
# msgpack only supports 'encoding' starting in 0.4.0. # msgpack only supports 'encoding' starting in 0.4.0.
# Due to this, if we don't need it, don't pass it at all so # Due to this, if we don't need it, don't pass it at all so
# that under Python 2 we can still work with older versions # that under Python 2 we can still work with older versions
# of msgpack. # of msgpack.
ret = msgpack.loads(msg, use_list=True, encoding=encoding) ret = msgpack.loads(msg, use_list=True, ext_hook=ext_type_decoder, encoding=encoding)
else: else:
ret = msgpack.loads(msg, use_list=True) ret = msgpack.loads(msg, use_list=True, ext_hook=ext_type_decoder)
if six.PY3 and encoding is None and not raw: if six.PY3 and encoding is None and not raw:
ret = salt.transport.frame.decode_embedded_strs(ret) ret = salt.transport.frame.decode_embedded_strs(ret)
except Exception as exc: except Exception as exc:
@ -175,19 +182,40 @@ class Serial(object):
Since this changes the wire protocol, this Since this changes the wire protocol, this
option should not be used outside of IPC. option should not be used outside of IPC.
''' '''
def ext_type_encoder(obj):
if isinstance(obj, six.integer_types):
# msgpack can't handle the very long Python longs for jids
# Convert any very long longs to strings
return six.text_type(obj)
elif isinstance(obj, datetime.datetime):
# msgpack doesn't support datetime.datetime datatype
# So here we have converted datetime.datetime to custom datatype
# This is msgpack Extended types numbered 78
return msgpack.ExtType(78, salt.utils.stringutils.to_bytes(
obj.strftime('%Y%m%dT%H:%M:%S.%f')))
# The same for immutable types
elif isinstance(obj, immutabletypes.ImmutableDict):
return dict(obj)
elif isinstance(obj, immutabletypes.ImmutableList):
return list(obj)
elif isinstance(obj, (set, immutabletypes.ImmutableSet)):
# msgpack can't handle set so translate it to tuple
return tuple(obj)
# Nothing known exceptions found. Let msgpack raise it's own.
return obj
try: try:
if msgpack.version >= (0, 4, 0): if msgpack.version >= (0, 4, 0):
# msgpack only supports 'use_bin_type' starting in 0.4.0. # msgpack only supports 'use_bin_type' starting in 0.4.0.
# Due to this, if we don't need it, don't pass it at all so # Due to this, if we don't need it, don't pass it at all so
# that under Python 2 we can still work with older versions # that under Python 2 we can still work with older versions
# of msgpack. # of msgpack.
return msgpack.dumps(msg, use_bin_type=use_bin_type) return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type)
else: else:
return msgpack.dumps(msg) return msgpack.dumps(msg, default=ext_type_encoder)
except (OverflowError, msgpack.exceptions.PackValueError): except (OverflowError, msgpack.exceptions.PackValueError):
# msgpack can't handle the very long Python longs for jids # msgpack<=0.4.6 don't call ext encoder on very long integers raising the error instead.
# Convert any very long longs to strings # Convert any very long longs to strings and call dumps again.
# We borrow the technique used by TypeError below
def verylong_encoder(obj): def verylong_encoder(obj):
if isinstance(obj, dict): if isinstance(obj, dict):
for key, value in six.iteritems(obj.copy()): for key, value in six.iteritems(obj.copy()):
@ -198,102 +226,18 @@ class Serial(object):
for idx, entry in enumerate(obj): for idx, entry in enumerate(obj):
obj[idx] = verylong_encoder(entry) obj[idx] = verylong_encoder(entry)
return obj return obj
# This is a spurious lint failure as we are gating this check # A value of an Integer object is limited from -(2^63) upto (2^64)-1 by MessagePack
# behind a check for six.PY2. # spec. Here we care only of JIDs that are positive integers.
if six.PY2 and isinstance(obj, long) and long > pow(2, 64): # pylint: disable=incompatible-py3-code if isinstance(obj, six.integer_types) and obj >= pow(2, 64):
return six.text_type(obj)
elif six.PY3 and isinstance(obj, int) and int > pow(2, 64):
return six.text_type(obj) return six.text_type(obj)
else: else:
return obj return obj
msg = verylong_encoder(msg)
if msgpack.version >= (0, 4, 0): if msgpack.version >= (0, 4, 0):
return msgpack.dumps(verylong_encoder(msg), use_bin_type=use_bin_type) return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type)
else: else:
return msgpack.dumps(verylong_encoder(msg)) return msgpack.dumps(msg, default=ext_type_encoder)
except TypeError as e:
# msgpack doesn't support datetime.datetime datatype
# So here we have converted datetime.datetime to custom datatype
# This is msgpack Extended types numbered 78
def default(obj):
return msgpack.ExtType(78, obj)
def dt_encode(obj):
datetime_str = obj.strftime("%Y%m%dT%H:%M:%S.%f")
if msgpack.version >= (0, 4, 0):
return msgpack.packb(datetime_str, default=default, use_bin_type=use_bin_type)
else:
return msgpack.packb(datetime_str, default=default)
def datetime_encoder(obj):
if isinstance(obj, dict):
for key, value in six.iteritems(obj.copy()):
encodedkey = datetime_encoder(key)
if key != encodedkey:
del obj[key]
key = encodedkey
obj[key] = datetime_encoder(value)
return dict(obj)
elif isinstance(obj, (list, tuple)):
obj = list(obj)
for idx, entry in enumerate(obj):
obj[idx] = datetime_encoder(entry)
return obj
if isinstance(obj, datetime.datetime):
return dt_encode(obj)
else:
return obj
def immutable_encoder(obj):
log.debug('IMMUTABLE OBJ: %s', obj)
if isinstance(obj, immutabletypes.ImmutableDict):
return dict(obj)
if isinstance(obj, immutabletypes.ImmutableList):
return list(obj)
if isinstance(obj, immutabletypes.ImmutableSet):
return set(obj)
if "datetime.datetime" in six.text_type(e):
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(datetime_encoder(msg), use_bin_type=use_bin_type)
else:
return msgpack.dumps(datetime_encoder(msg))
elif "Immutable" in six.text_type(e):
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(msg, default=immutable_encoder, use_bin_type=use_bin_type)
else:
return msgpack.dumps(msg, default=immutable_encoder)
if msgpack.version >= (0, 2, 0):
# Should support OrderedDict serialization, so, let's
# raise the exception
raise
# msgpack is < 0.2.0, let's make its life easier
# Since OrderedDict is identified as a dictionary, we can't
# make use of msgpack custom types, we will need to convert by
# hand.
# This means iterating through all elements of a dictionary or
# list/tuple
def odict_encoder(obj):
if isinstance(obj, dict):
for key, value in six.iteritems(obj.copy()):
obj[key] = odict_encoder(value)
return dict(obj)
elif isinstance(obj, (list, tuple)):
obj = list(obj)
for idx, entry in enumerate(obj):
obj[idx] = odict_encoder(entry)
return obj
return obj
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(odict_encoder(msg), use_bin_type=use_bin_type)
else:
return msgpack.dumps(odict_encoder(msg))
except (SystemError, TypeError) as exc: # pylint: disable=W0705
log.critical(
'Unable to serialize message! Consider upgrading msgpack. '
'Message which failed was %s, with exception %s', msg, exc
)
def dump(self, msg, fn_): def dump(self, msg, fn_):
''' '''

View File

@ -7,24 +7,24 @@
~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~
''' '''
# Import Salt libs # Import python libs
from __future__ import absolute_import, print_function, unicode_literals from __future__ import absolute_import, print_function, unicode_literals
import time import time
import errno import errno
import threading import threading
import datetime
# Import Salt Testing libs # Import Salt Testing libs
from tests.support.unit import skipIf, TestCase from tests.support.unit import skipIf, TestCase
from tests.support.helpers import MockWraps from tests.support.mock import NO_MOCK, NO_MOCK_REASON
from tests.support.mock import NO_MOCK, NO_MOCK_REASON, patch
# Import salt libs # Import Salt libs
import salt.payload from salt.utils import immutabletypes
from salt.utils.odict import OrderedDict from salt.utils.odict import OrderedDict
import salt.exceptions import salt.exceptions
import salt.payload
# Import 3rd-party libs # Import 3rd-party libs
import msgpack
import zmq import zmq
from salt.ext import six from salt.ext import six
@ -49,15 +49,109 @@ class PayloadTestCase(TestCase):
self.assertNoOrderedDict(chunk) self.assertNoOrderedDict(chunk)
def test_list_nested_odicts(self): def test_list_nested_odicts(self):
with patch('msgpack.version', (0, 1, 13)): payload = salt.payload.Serial('msgpack')
msgpack.dumps = MockWraps( idata = {'pillar': [OrderedDict(environment='dev')]}
msgpack.dumps, 1, TypeError('ODict TypeError Forced') odata = payload.loads(payload.dumps(idata.copy()))
) self.assertNoOrderedDict(odata)
payload = salt.payload.Serial('msgpack') self.assertEqual(idata, odata)
idata = {'pillar': [OrderedDict(environment='dev')]}
odata = payload.loads(payload.dumps(idata.copy())) def test_datetime_dump_load(self):
self.assertNoOrderedDict(odata) '''
self.assertEqual(idata, odata) Check the custom datetime handler can understand itself
'''
payload = salt.payload.Serial('msgpack')
dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
idata = {dtvalue: dtvalue}
sdata = payload.dumps(idata.copy())
odata = payload.loads(sdata)
self.assertEqual(
sdata,
b'\x81\xc7\x18N20010203T04:05:06.000007\xc7\x18N20010203T04:05:06.000007')
self.assertEqual(idata, odata)
def test_verylong_dump_load(self):
'''
Test verylong encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'jid': 20180227140750302662}
sdata = payload.dumps(idata.copy())
odata = payload.loads(sdata)
idata['jid'] = '{0}'.format(idata['jid'])
self.assertEqual(idata, odata)
def test_immutable_dict_dump_load(self):
'''
Test immutable dict encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'dict': {'key': 'value'}}
sdata = payload.dumps({'dict': immutabletypes.ImmutableDict(idata['dict'])})
odata = payload.loads(sdata)
self.assertEqual(idata, odata)
def test_immutable_list_dump_load(self):
'''
Test immutable list encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'list': [1, 2, 3]}
sdata = payload.dumps({'list': immutabletypes.ImmutableList(idata['list'])})
odata = payload.loads(sdata)
self.assertEqual(idata, odata)
def test_immutable_set_dump_load(self):
'''
Test immutable set encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'set': ['red', 'green', 'blue']}
sdata = payload.dumps({'set': immutabletypes.ImmutableSet(idata['set'])})
odata = payload.loads(sdata)
self.assertEqual(idata, odata)
def test_odict_dump_load(self):
'''
Test odict just works. It wasn't until msgpack 0.2.0
'''
payload = salt.payload.Serial('msgpack')
data = OrderedDict()
data['a'] = 'b'
data['y'] = 'z'
data['j'] = 'k'
data['w'] = 'x'
sdata = payload.dumps({'set': data})
odata = payload.loads(sdata)
self.assertEqual({'set': dict(data)}, odata)
def test_mixed_dump_load(self):
'''
Test we can handle all exceptions at once
'''
payload = salt.payload.Serial('msgpack')
dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
od = OrderedDict()
od['a'] = 'b'
od['y'] = 'z'
od['j'] = 'k'
od['w'] = 'x'
idata = {dtvalue: dtvalue, # datetime
'jid': 20180227140750302662, # long int
'dict': immutabletypes.ImmutableDict({'key': 'value'}), # immutable dict
'list': immutabletypes.ImmutableList([1, 2, 3]), # immutable list
'set': immutabletypes.ImmutableSet(('red', 'green', 'blue')), # immutable set
'odict': od, # odict
}
edata = {dtvalue: dtvalue, # datetime, == input
'jid': '20180227140750302662', # string repr of long int
'dict': {'key': 'value'}, # builtin dict
'list': [1, 2, 3], # builtin list
'set': ['red', 'green', 'blue'], # builtin set
'odict': dict(od), # builtin dict
}
sdata = payload.dumps(idata)
odata = payload.loads(sdata)
self.assertEqual(edata, odata)
class SREQTestCase(TestCase): class SREQTestCase(TestCase):