From 792f05403c4f3bd16fdce355605026d0f9dac9c2 Mon Sep 17 00:00:00 2001 From: Erik Johnson Date: Wed, 30 Jan 2019 15:13:16 -0600 Subject: [PATCH] Add a CaseInsensitiveDict implementation --- salt/output/nested.py | 7 ++-- salt/payload.py | 3 ++ salt/utils/data.py | 86 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 5 deletions(-) diff --git a/salt/output/nested.py b/salt/output/nested.py index bb92644202..b80a90629c 100644 --- a/salt/output/nested.py +++ b/salt/output/nested.py @@ -33,6 +33,7 @@ import salt.utils.color import salt.utils.odict import salt.utils.stringutils from salt.ext import six +from collections import Mapping class NestDisplay(object): @@ -142,7 +143,7 @@ class NestDisplay(object): if self.retcode != 0: color = self.RED for ind in ret: - if isinstance(ind, (list, tuple, dict)): + if isinstance(ind, (list, tuple, Mapping)): out.append( self.ustring( indent, @@ -150,11 +151,11 @@ class NestDisplay(object): '|_' ) ) - prefix = '' if isinstance(ind, dict) else '- ' + prefix = '' if isinstance(ind, Mapping) else '- ' self.display(ind, indent + 2, prefix, out) else: self.display(ind, indent, '- ', out) - elif isinstance(ret, dict): + elif isinstance(ret, Mapping): if indent: color = self.CYAN if self.retcode != 0: diff --git a/salt/payload.py b/salt/payload.py index 5df6458b86..1e1f08846c 100644 --- a/salt/payload.py +++ b/salt/payload.py @@ -19,6 +19,7 @@ import salt.transport.frame import salt.utils.immutabletypes as immutabletypes import salt.utils.stringutils from salt.exceptions import SaltReqTimeoutError +from salt.utils.data import CaseInsensitiveDict # Import third party libs from salt.ext import six @@ -205,6 +206,8 @@ class Serial(object): elif isinstance(obj, (set, immutabletypes.ImmutableSet)): # msgpack can't handle set so translate it to tuple return tuple(obj) + elif isinstance(obj, CaseInsensitiveDict): + return dict(obj) # Nothing known exceptions found. Let msgpack raise it's own. return obj diff --git a/salt/utils/data.py b/salt/utils/data.py index c69ba900e2..1621946502 100644 --- a/salt/utils/data.py +++ b/salt/utils/data.py @@ -13,9 +13,9 @@ import logging import re try: - from collections.abc import Mapping + from collections.abc import Mapping, MutableMapping, Sequence except ImportError: - from collections import Mapping + from collections import Mapping, MutableMapping, Sequence # Import Salt libs import salt.utils.dictupdate @@ -24,6 +24,7 @@ import salt.utils.yaml from salt.defaults import DEFAULT_TARGET_DELIM from salt.exceptions import SaltException from salt.utils.decorators.jinja import jinja_filter +from salt.utils.odict import OrderedDict # Import 3rd-party libs from salt.ext import six @@ -32,6 +33,87 @@ from salt.ext.six.moves import range # pylint: disable=redefined-builtin log = logging.getLogger(__name__) +class CaseInsensitiveDict(MutableMapping): + ''' + Inspired by requests' case-insensitive dict implementation, but works with + non-string keys as well. + ''' + def __init__(self, init=None): + ''' + Force internal dict to be ordered to ensure a consistent iteration + order, irrespective of case. + ''' + self._data = OrderedDict() + self.update(init or {}) + + def __len__(self): + return len(self._data) + + def __setitem__(self, key, value): + # Store the case-sensitive key so it is available for dict iteration + self._data[to_lowercase(key)] = (key, value) + + def __delitem__(self, key): + del self._data[to_lowercase(key)] + + def __getitem__(self, key): + return self._data[to_lowercase(key)][1] + + def __iter__(self): + return (item[0] for item in six.itervalues(self._data)) + + def __eq__(self, rval): + if not isinstance(rval, Mapping): + # Comparing to non-mapping type (e.g. int) is always False + return False + return dict(self.items_lower()) == dict(CaseInsensitiveDict(rval).items_lower()) + + def __repr__(self): + return repr(dict(six.iteritems(self))) + + def items_lower(self): + ''' + Returns a generator iterating over keys and values, with the keys all + being lowercase. + ''' + return ((key, val[1]) for key, val in six.iteritems(self._data)) + + def copy(self): + ''' + Returns a copy of the object + ''' + return CaseInsensitiveDict(six.iteritems(self._data)) + + +def __change_case(data, attr, preserve_dict_class=False): + try: + return getattr(data, attr)() + except AttributeError: + pass + + data_type = data.__class__ + + if isinstance(data, Mapping): + return (data_type if preserve_dict_class else dict)( + (__change_case(key, attr, preserve_dict_class), + __change_case(val, attr, preserve_dict_class)) + for key, val in six.iteritems(data) + ) + elif isinstance(data, Sequence): + return data_type( + __change_case(item, attr, preserve_dict_class) for item in data) + else: + return data + + +def to_lowercase(data, preserve_dict_class=False): + return __change_case(data, 'lower', preserve_dict_class) + + +def to_uppercase(data, preserve_dict_class=False): + return __change_case(data, 'upper', preserve_dict_class) + + @jinja_filter('compare_dicts') def compare_dicts(old=None, new=None): '''