Merge pull request #46269 from DSRCorporation/bugs/46202_msgpack_custom_types_rework

msgpack custom types rework
This commit is contained in:
Nicole Thomas 2018-03-02 11:22:55 -05:00 committed by GitHub
commit 00ca4f01c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 151 additions and 113 deletions

View File

@ -17,6 +17,7 @@ import salt.log
import salt.crypt
import salt.transport.frame
import salt.utils.immutabletypes as immutabletypes
import salt.utils.stringutils
from salt.exceptions import SaltReqTimeoutError
# Import third party libs
@ -128,15 +129,21 @@ class Serial(object):
the contents cannot be converted.
'''
try:
def ext_type_decoder(code, data):
if code == 78:
data = salt.utils.stringutils.to_unicode(data)
return datetime.datetime.strptime(data, '%Y%m%dT%H:%M:%S.%f')
return data
gc.disable() # performance optimization for msgpack
if msgpack.version >= (0, 4, 0):
# msgpack only supports 'encoding' starting in 0.4.0.
# Due to this, if we don't need it, don't pass it at all so
# that under Python 2 we can still work with older versions
# of msgpack.
ret = msgpack.loads(msg, use_list=True, encoding=encoding)
ret = msgpack.loads(msg, use_list=True, ext_hook=ext_type_decoder, encoding=encoding)
else:
ret = msgpack.loads(msg, use_list=True)
ret = msgpack.loads(msg, use_list=True, ext_hook=ext_type_decoder)
if six.PY3 and encoding is None and not raw:
ret = salt.transport.frame.decode_embedded_strs(ret)
except Exception as exc:
@ -175,19 +182,40 @@ class Serial(object):
Since this changes the wire protocol, this
option should not be used outside of IPC.
'''
def ext_type_encoder(obj):
if isinstance(obj, six.integer_types):
# msgpack can't handle the very long Python longs for jids
# Convert any very long longs to strings
return six.text_type(obj)
elif isinstance(obj, datetime.datetime):
# msgpack doesn't support datetime.datetime datatype
# So here we have converted datetime.datetime to custom datatype
# This is msgpack Extended types numbered 78
return msgpack.ExtType(78, salt.utils.stringutils.to_bytes(
obj.strftime('%Y%m%dT%H:%M:%S.%f')))
# The same for immutable types
elif isinstance(obj, immutabletypes.ImmutableDict):
return dict(obj)
elif isinstance(obj, immutabletypes.ImmutableList):
return list(obj)
elif isinstance(obj, (set, immutabletypes.ImmutableSet)):
# msgpack can't handle set so translate it to tuple
return tuple(obj)
# Nothing known exceptions found. Let msgpack raise it's own.
return obj
try:
if msgpack.version >= (0, 4, 0):
# msgpack only supports 'use_bin_type' starting in 0.4.0.
# Due to this, if we don't need it, don't pass it at all so
# that under Python 2 we can still work with older versions
# of msgpack.
return msgpack.dumps(msg, use_bin_type=use_bin_type)
return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type)
else:
return msgpack.dumps(msg)
return msgpack.dumps(msg, default=ext_type_encoder)
except (OverflowError, msgpack.exceptions.PackValueError):
# msgpack can't handle the very long Python longs for jids
# Convert any very long longs to strings
# We borrow the technique used by TypeError below
# msgpack<=0.4.6 don't call ext encoder on very long integers raising the error instead.
# Convert any very long longs to strings and call dumps again.
def verylong_encoder(obj):
if isinstance(obj, dict):
for key, value in six.iteritems(obj.copy()):
@ -198,102 +226,18 @@ class Serial(object):
for idx, entry in enumerate(obj):
obj[idx] = verylong_encoder(entry)
return obj
# This is a spurious lint failure as we are gating this check
# behind a check for six.PY2.
if six.PY2 and isinstance(obj, long) and long > pow(2, 64): # pylint: disable=incompatible-py3-code
return six.text_type(obj)
elif six.PY3 and isinstance(obj, int) and int > pow(2, 64):
# A value of an Integer object is limited from -(2^63) upto (2^64)-1 by MessagePack
# spec. Here we care only of JIDs that are positive integers.
if isinstance(obj, six.integer_types) and obj >= pow(2, 64):
return six.text_type(obj)
else:
return obj
msg = verylong_encoder(msg)
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(verylong_encoder(msg), use_bin_type=use_bin_type)
return msgpack.dumps(msg, default=ext_type_encoder, use_bin_type=use_bin_type)
else:
return msgpack.dumps(verylong_encoder(msg))
except TypeError as e:
# msgpack doesn't support datetime.datetime datatype
# So here we have converted datetime.datetime to custom datatype
# This is msgpack Extended types numbered 78
def default(obj):
return msgpack.ExtType(78, obj)
def dt_encode(obj):
datetime_str = obj.strftime("%Y%m%dT%H:%M:%S.%f")
if msgpack.version >= (0, 4, 0):
return msgpack.packb(datetime_str, default=default, use_bin_type=use_bin_type)
else:
return msgpack.packb(datetime_str, default=default)
def datetime_encoder(obj):
if isinstance(obj, dict):
for key, value in six.iteritems(obj.copy()):
encodedkey = datetime_encoder(key)
if key != encodedkey:
del obj[key]
key = encodedkey
obj[key] = datetime_encoder(value)
return dict(obj)
elif isinstance(obj, (list, tuple)):
obj = list(obj)
for idx, entry in enumerate(obj):
obj[idx] = datetime_encoder(entry)
return obj
if isinstance(obj, datetime.datetime):
return dt_encode(obj)
else:
return obj
def immutable_encoder(obj):
log.debug('IMMUTABLE OBJ: %s', obj)
if isinstance(obj, immutabletypes.ImmutableDict):
return dict(obj)
if isinstance(obj, immutabletypes.ImmutableList):
return list(obj)
if isinstance(obj, immutabletypes.ImmutableSet):
return set(obj)
if "datetime.datetime" in six.text_type(e):
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(datetime_encoder(msg), use_bin_type=use_bin_type)
else:
return msgpack.dumps(datetime_encoder(msg))
elif "Immutable" in six.text_type(e):
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(msg, default=immutable_encoder, use_bin_type=use_bin_type)
else:
return msgpack.dumps(msg, default=immutable_encoder)
if msgpack.version >= (0, 2, 0):
# Should support OrderedDict serialization, so, let's
# raise the exception
raise
# msgpack is < 0.2.0, let's make its life easier
# Since OrderedDict is identified as a dictionary, we can't
# make use of msgpack custom types, we will need to convert by
# hand.
# This means iterating through all elements of a dictionary or
# list/tuple
def odict_encoder(obj):
if isinstance(obj, dict):
for key, value in six.iteritems(obj.copy()):
obj[key] = odict_encoder(value)
return dict(obj)
elif isinstance(obj, (list, tuple)):
obj = list(obj)
for idx, entry in enumerate(obj):
obj[idx] = odict_encoder(entry)
return obj
return obj
if msgpack.version >= (0, 4, 0):
return msgpack.dumps(odict_encoder(msg), use_bin_type=use_bin_type)
else:
return msgpack.dumps(odict_encoder(msg))
except (SystemError, TypeError) as exc: # pylint: disable=W0705
log.critical(
'Unable to serialize message! Consider upgrading msgpack. '
'Message which failed was %s, with exception %s', msg, exc
)
return msgpack.dumps(msg, default=ext_type_encoder)
def dump(self, msg, fn_):
'''

View File

@ -7,24 +7,24 @@
~~~~~~~~~~~~~~~~~~~~~~~
'''
# Import Salt libs
# Import python libs
from __future__ import absolute_import, print_function, unicode_literals
import time
import errno
import threading
import datetime
# Import Salt Testing libs
from tests.support.unit import skipIf, TestCase
from tests.support.helpers import MockWraps
from tests.support.mock import NO_MOCK, NO_MOCK_REASON, patch
from tests.support.mock import NO_MOCK, NO_MOCK_REASON
# Import salt libs
import salt.payload
# Import Salt libs
from salt.utils import immutabletypes
from salt.utils.odict import OrderedDict
import salt.exceptions
import salt.payload
# Import 3rd-party libs
import msgpack
import zmq
from salt.ext import six
@ -49,16 +49,110 @@ class PayloadTestCase(TestCase):
self.assertNoOrderedDict(chunk)
def test_list_nested_odicts(self):
with patch('msgpack.version', (0, 1, 13)):
msgpack.dumps = MockWraps(
msgpack.dumps, 1, TypeError('ODict TypeError Forced')
)
payload = salt.payload.Serial('msgpack')
idata = {'pillar': [OrderedDict(environment='dev')]}
odata = payload.loads(payload.dumps(idata.copy()))
self.assertNoOrderedDict(odata)
self.assertEqual(idata, odata)
def test_datetime_dump_load(self):
'''
Check the custom datetime handler can understand itself
'''
payload = salt.payload.Serial('msgpack')
dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
idata = {dtvalue: dtvalue}
sdata = payload.dumps(idata.copy())
odata = payload.loads(sdata)
self.assertEqual(
sdata,
b'\x81\xc7\x18N20010203T04:05:06.000007\xc7\x18N20010203T04:05:06.000007')
self.assertEqual(idata, odata)
def test_verylong_dump_load(self):
'''
Test verylong encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'jid': 20180227140750302662}
sdata = payload.dumps(idata.copy())
odata = payload.loads(sdata)
idata['jid'] = '{0}'.format(idata['jid'])
self.assertEqual(idata, odata)
def test_immutable_dict_dump_load(self):
'''
Test immutable dict encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'dict': {'key': 'value'}}
sdata = payload.dumps({'dict': immutabletypes.ImmutableDict(idata['dict'])})
odata = payload.loads(sdata)
self.assertEqual(idata, odata)
def test_immutable_list_dump_load(self):
'''
Test immutable list encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'list': [1, 2, 3]}
sdata = payload.dumps({'list': immutabletypes.ImmutableList(idata['list'])})
odata = payload.loads(sdata)
self.assertEqual(idata, odata)
def test_immutable_set_dump_load(self):
'''
Test immutable set encoder/decoder
'''
payload = salt.payload.Serial('msgpack')
idata = {'set': ['red', 'green', 'blue']}
sdata = payload.dumps({'set': immutabletypes.ImmutableSet(idata['set'])})
odata = payload.loads(sdata)
self.assertEqual(idata, odata)
def test_odict_dump_load(self):
'''
Test odict just works. It wasn't until msgpack 0.2.0
'''
payload = salt.payload.Serial('msgpack')
data = OrderedDict()
data['a'] = 'b'
data['y'] = 'z'
data['j'] = 'k'
data['w'] = 'x'
sdata = payload.dumps({'set': data})
odata = payload.loads(sdata)
self.assertEqual({'set': dict(data)}, odata)
def test_mixed_dump_load(self):
'''
Test we can handle all exceptions at once
'''
payload = salt.payload.Serial('msgpack')
dtvalue = datetime.datetime(2001, 2, 3, 4, 5, 6, 7)
od = OrderedDict()
od['a'] = 'b'
od['y'] = 'z'
od['j'] = 'k'
od['w'] = 'x'
idata = {dtvalue: dtvalue, # datetime
'jid': 20180227140750302662, # long int
'dict': immutabletypes.ImmutableDict({'key': 'value'}), # immutable dict
'list': immutabletypes.ImmutableList([1, 2, 3]), # immutable list
'set': immutabletypes.ImmutableSet(('red', 'green', 'blue')), # immutable set
'odict': od, # odict
}
edata = {dtvalue: dtvalue, # datetime, == input
'jid': '20180227140750302662', # string repr of long int
'dict': {'key': 'value'}, # builtin dict
'list': [1, 2, 3], # builtin list
'set': ['red', 'green', 'blue'], # builtin set
'odict': dict(od), # builtin dict
}
sdata = payload.dumps(idata)
odata = payload.loads(sdata)
self.assertEqual(edata, odata)
class SREQTestCase(TestCase):
port = 8845 # TODO: dynamically assign a port?