Merge pull request #41559 from rallytime/merge-develop

[develop] Merge forward from nitrogen to develop
This commit is contained in:
Mike Place 2017-06-02 10:28:29 -05:00 committed by GitHub
commit 3190b88df0
21 changed files with 778 additions and 479 deletions

View File

@ -27,7 +27,6 @@ except Exception:
# pylint: disable=import-error,no-name-in-module
import salt.ext.six as six
from salt.ext.six import string_types, text_type
from salt.ext.six.moves.urllib.parse import urlparse
# pylint: enable=import-error,no-name-in-module
@ -108,10 +107,10 @@ FLO_DIR = os.path.join(
VALID_OPTS = {
# The address of the salt master. May be specified as IP address or hostname
'master': (string_types, list),
'master': (six.string_types, list),
# The TCP/UDP port of the master to connect to in order to listen to publications
'master_port': (string_types, int),
'master_port': (six.string_types, int),
# The behaviour of the minion when connecting to a master. Can specify 'failover',
# 'disable' or 'func'. If 'func' is specified, the 'master' option should be set to an
@ -460,7 +459,7 @@ VALID_OPTS = {
# If a minion is running an esky build of salt, upgrades can be performed using the url
# defined here. See saltutil.update() for additional information
'update_url': (bool, string_types),
'update_url': (bool, six.string_types),
# If using update_url with saltutil.update(), provide a list of services to be restarted
# post-install
@ -489,7 +488,7 @@ VALID_OPTS = {
# Specify one or more returners in which all events will be sent to. Requires that the returners
# in question have an event_return(event) function!
'event_return': (list, string_types),
'event_return': (list, six.string_types),
# The number of events to queue up in memory before pushing them down the pipe to an event
# returner specified by 'event_return'
@ -679,7 +678,7 @@ VALID_OPTS = {
'ping_on_rotate': bool,
'peer': dict,
'preserve_minion_cache': bool,
'syndic_master': (string_types, list),
'syndic_master': (six.string_types, list),
# The behaviour of the multimaster syndic when connection to a master of masters failed. Can
# specify 'random' (default) or 'ordered'. If set to 'random' masters will be iterated in random
@ -696,8 +695,8 @@ VALID_OPTS = {
'token_expire_user_override': (bool, dict),
'file_recv': bool,
'file_recv_max_size': int,
'file_ignore_regex': (list, string_types),
'file_ignore_glob': (list, string_types),
'file_ignore_regex': (list, six.string_types),
'file_ignore_glob': (list, six.string_types),
'fileserver_backend': list,
'fileserver_followsymlinks': bool,
'fileserver_ignoresymlinks': bool,
@ -1924,7 +1923,7 @@ def _read_conf_file(path):
else:
conf_opts['id'] = sdecode(conf_opts['id'])
for key, value in six.iteritems(conf_opts.copy()):
if isinstance(value, text_type) and six.PY2:
if isinstance(value, six.text_type) and six.PY2:
# We do not want unicode settings
conf_opts[key] = value.encode('utf-8')
return conf_opts
@ -2029,7 +2028,7 @@ def include_config(include, orig_path, verbose, exit_on_config_errors=False):
# defaults, not actually loading the whole configuration.
return {}
if isinstance(include, str):
if isinstance(include, six.string_types):
include = [include]
configuration = {}
@ -2116,7 +2115,7 @@ def insert_system_path(opts, paths):
'''
Inserts path into python path taking into consideration 'root_dir' option.
'''
if isinstance(paths, str):
if isinstance(paths, six.string_types):
paths = [paths]
for path in paths:
path_options = {'path': path, 'root_dir': opts['root_dir']}
@ -2307,7 +2306,7 @@ def apply_sdb(opts, sdb_opts=None):
import salt.utils.sdb
if sdb_opts is None:
sdb_opts = opts
if isinstance(sdb_opts, string_types) and sdb_opts.startswith('sdb://'):
if isinstance(sdb_opts, six.string_types) and sdb_opts.startswith('sdb://'):
return salt.utils.sdb.sdb_get(sdb_opts, opts)
elif isinstance(sdb_opts, dict):
for key, value in six.iteritems(sdb_opts):
@ -2402,7 +2401,7 @@ def cloud_config(path, env_var='SALT_CLOUD_CONFIG', defaults=None,
'deploy_scripts_search_path',
defaults.get('deploy_scripts_search_path', 'cloud.deploy.d')
)
if isinstance(deploy_scripts_search_path, string_types):
if isinstance(deploy_scripts_search_path, six.string_types):
deploy_scripts_search_path = [deploy_scripts_search_path]
# Check the provided deploy scripts search path removing any non existing
@ -3355,7 +3354,7 @@ def _update_ssl_config(opts):
val = opts['ssl'].get(key)
if val is None:
continue
if not isinstance(val, string_types) or not val.startswith(prefix) or not hasattr(ssl, val):
if not isinstance(val, six.string_types) or not val.startswith(prefix) or not hasattr(ssl, val):
message = 'SSL option \'{0}\' must be set to one of the following values: \'{1}\'.' \
.format(key, '\', \''.join([val for val in dir(ssl) if val.startswith(prefix)]))
log.error(message)

View File

@ -114,7 +114,7 @@ def _mbcs_to_unicode_wrap(obj, vtype):
return obj
if isinstance(obj, list):
return [_mbcs_to_unicode(x) for x in obj]
elif isinstance(obj, int):
elif isinstance(obj, six.integer_types):
return obj
else:
return _mbcs_to_unicode(obj)

View File

@ -156,7 +156,7 @@ def _replace_auth_key(
lines.append(line)
continue
comps = re.findall(r'((.*)\s)?(ssh-[a-z0-9-]+|ecdsa-[a-z0-9-]+)\s([a-zA-Z0-9+/]+={0,2})(\s(.*))?', line)
if comps[0][3] == key:
if len(comps) > 0 and len(comps[0]) > 3 and comps[0][3] == key:
lines.append(auth_line)
else:
lines.append(line)

View File

@ -1610,7 +1610,10 @@ def cert_info(cert_path, digest='sha256'):
if hasattr(cert, 'get_signature_algorithm'):
try:
ret['signature_algorithm'] = cert.get_signature_algorithm()
value = cert.get_signature_algorithm()
if isinstance(value, bytes):
value = salt.utils.to_str(value, __salt_system_encoding__)
ret['signature_algorithm'] = value
except AttributeError:
# On py3 at least
# AttributeError: cdata 'X509 *' points to an opaque type: cannot read fields

View File

@ -74,6 +74,13 @@ try:
except ImportError:
HAS_WINDOWS_MODULES = False
# This is to fix the pylint error: E0602: Undefined variable "WindowsError"
try:
from exceptions import WindowsError
except ImportError:
class WindowsError(OSError):
pass
HAS_WIN_DACL = False
try:
if salt.utils.is_windows():
@ -1247,7 +1254,10 @@ def mkdir(path,
not apply to parent directories if they must be created
Returns:
bool: True if successful, otherwise raise an error
bool: True if successful
Raises:
CommandExecutionError: If unsuccessful
CLI Example:
@ -1272,24 +1282,27 @@ def mkdir(path,
if not os.path.isdir(path):
# Make the directory
os.mkdir(path)
try:
# Make the directory
os.mkdir(path)
# Set owner
if owner:
salt.utils.win_dacl.set_owner(path, owner)
# Set owner
if owner:
salt.utils.win_dacl.set_owner(path, owner)
# Set permissions
set_perms(path, grant_perms, deny_perms, inheritance)
# Set permissions
set_perms(path, grant_perms, deny_perms, inheritance)
except WindowsError as exc:
raise CommandExecutionError(exc)
return True
def makedirs(path,
owner=None,
grant_perms=None,
deny_perms=None,
inheritance=True):
def makedirs_(path,
owner=None,
grant_perms=None,
deny_perms=None,
inheritance=True):
'''
Ensure that the parent directory containing this path is available.
@ -1333,6 +1346,12 @@ def makedirs(path,
the path ends with a trailing slash like ``C:\\temp\\test\\``, then it
would be treated as ``C:\\temp\\test\\``.
Returns:
bool: True if successful
Raises:
CommandExecutionError: If unsuccessful
CLI Example:
.. code-block:: bash
@ -1385,7 +1404,9 @@ def makedirs(path,
for directory_to_create in directories_to_create:
# all directories have the user, group and mode set!!
log.debug('Creating directory: %s', directory_to_create)
mkdir(path, owner, grant_perms, deny_perms, inheritance)
mkdir(directory_to_create, owner, grant_perms, deny_perms, inheritance)
return True
def makedirs_perms(path,

View File

@ -1,36 +1,40 @@
# -*- coding: utf-8 -*-
'''
Module for configuring Windows Firewall
Module for configuring Windows Firewall using ``netsh``
'''
from __future__ import absolute_import
# Import python libs
import re
import logging
# Import salt libs
import salt.utils
from salt.ext import six
from salt.exceptions import CommandExecutionError
# Define the module's virtual name
__virtualname__ = 'firewall'
log = logging.getLogger(__name__)
def __virtual__():
'''
Only works on Windows systems
'''
if salt.utils.is_windows():
return __virtualname__
return (False, "Module win_firewall: module only works on Windows systems")
if not salt.utils.is_windows():
return False, "Module win_firewall: module only available on Windows"
return __virtualname__
def get_config():
'''
Get the status of all the firewall profiles
Returns:
dict: A dictionary of all profiles on the system
Raises:
CommandExecutionError: If the command fails
CLI Example:
.. code-block:: bash
@ -41,7 +45,14 @@ def get_config():
curr = None
cmd = ['netsh', 'advfirewall', 'show', 'allprofiles']
for line in __salt__['cmd.run'](cmd, python_shell=False).splitlines():
ret = __salt__['cmd.run_all'](cmd, python_shell=False, ignore_retcode=True)
if ret['retcode'] != 0:
raise CommandExecutionError(ret['stdout'])
# There may be some problems with this depending on how `netsh` is localized
# It's looking for lines that contain `Profile Settings` or start with
# `State` which may be different in different localizations
for line in ret['stdout'].splitlines():
if not curr:
tmp = re.search('(.*) Profile Settings:', line)
if tmp:
@ -55,7 +66,22 @@ def get_config():
def disable(profile='allprofiles'):
'''
Disable firewall profile :param profile: (default: allprofiles)
Disable firewall profile
Args:
profile (Optional[str]): The name of the profile to disable. Default is
``allprofiles``. Valid options are:
- allprofiles
- domainprofile
- privateprofile
- publicprofile
Returns:
bool: True if successful
Raises:
CommandExecutionError: If the command fails
CLI Example:
@ -64,15 +90,34 @@ def disable(profile='allprofiles'):
salt '*' firewall.disable
'''
cmd = ['netsh', 'advfirewall', 'set', profile, 'state', 'off']
return __salt__['cmd.run'](cmd, python_shell=False) == 'Ok.'
ret = __salt__['cmd.run_all'](cmd, python_shell=False, ignore_retcode=True)
if ret['retcode'] != 0:
raise CommandExecutionError(ret['stdout'])
return True
def enable(profile='allprofiles'):
'''
Enable firewall profile :param profile: (default: allprofiles)
.. versionadded:: 2015.5.0
Enable firewall profile
Args:
profile (Optional[str]): The name of the profile to enable. Default is
``allprofiles``. Valid options are:
- allprofiles
- domainprofile
- privateprofile
- publicprofile
Returns:
bool: True if successful
Raises:
CommandExecutionError: If the command fails
CLI Example:
.. code-block:: bash
@ -80,14 +125,28 @@ def enable(profile='allprofiles'):
salt '*' firewall.enable
'''
cmd = ['netsh', 'advfirewall', 'set', profile, 'state', 'on']
return __salt__['cmd.run'](cmd, python_shell=False) == 'Ok.'
ret = __salt__['cmd.run_all'](cmd, python_shell=False, ignore_retcode=True)
if ret['retcode'] != 0:
raise CommandExecutionError(ret['stdout'])
return True
def get_rule(name='all'):
'''
.. versionadded:: 2015.5.0
Get firewall rule(s) info
Display all matching rules as specified by name
Args:
name (Optional[str]): The full name of the rule. ``all`` will return all
rules. Default is ``all``
Returns:
dict: A dictionary of all rules or rules that match the name exactly
Raises:
CommandExecutionError: If the command fails
CLI Example:
@ -95,14 +154,13 @@ def get_rule(name='all'):
salt '*' firewall.get_rule 'MyAppPort'
'''
ret = {}
cmd = ['netsh', 'advfirewall', 'firewall', 'show', 'rule', 'name={0}'.format(name)]
ret[name] = __salt__['cmd.run'](cmd, python_shell=False)
cmd = ['netsh', 'advfirewall', 'firewall', 'show', 'rule',
'name="{0}"'.format(name)]
ret = __salt__['cmd.run_all'](cmd, python_shell=False, ignore_retcode=True)
if ret['retcode'] != 0:
raise CommandExecutionError(ret['stdout'])
if ret[name].strip() == 'No rules match the specified criteria.':
ret = False
return ret
return {name: ret['stdout']}
def add_rule(name, localport, protocol='tcp', action='allow', dir='in',
@ -110,7 +168,56 @@ def add_rule(name, localport, protocol='tcp', action='allow', dir='in',
'''
.. versionadded:: 2015.5.0
Add a new firewall rule
Add a new inbound or outbound rule to the firewall policy
Args:
name (str): The name of the rule. Must be unique and cannot be "all".
Required.
localport (int): The port the rule applies to. Must be a number between
0 and 65535. Can be a range. Can specify multiple ports separated by
commas. Required.
protocol (Optional[str]): The protocol. Can be any of the following:
- A number between 0 and 255
- icmpv4
- icmpv6
- tcp
- udp
- any
action (Optional[str]): The action the rule performs. Can be any of the
following:
- allow
- block
- bypass
dir (Optional[str]): The direction. Can be ``in`` or ``out``.
remoteip (Optional [str]): The remote IP. Can be any of the following:
- any
- localsubnet
- dns
- dhcp
- wins
- defaultgateway
- Any valid IPv4 address (192.168.0.12)
- Any valid IPv6 address (2002:9b3b:1a31:4:208:74ff:fe39:6c43)
- Any valid subnet (192.168.1.0/24)
- Any valid range of IP addresses (192.168.0.1-192.168.0.12)
- A list of valid IP addresses
Can be combinations of the above separated by commas.
Returns:
bool: True if successful
Raises:
CommandExecutionError: If the command fails
CLI Example:
@ -119,7 +226,6 @@ def add_rule(name, localport, protocol='tcp', action='allow', dir='in',
salt '*' firewall.add_rule 'test' '8080' 'tcp'
salt '*' firewall.add_rule 'test' '1' 'icmpv4'
salt '*' firewall.add_rule 'test_remote_ip' '8000' 'tcp' 'allow' 'in' '192.168.0.1'
'''
cmd = ['netsh', 'advfirewall', 'firewall', 'add', 'rule',
'name={0}'.format(name),
@ -128,42 +234,107 @@ def add_rule(name, localport, protocol='tcp', action='allow', dir='in',
'action={0}'.format(action),
'remoteip={0}'.format(remoteip)]
if 'icmpv4' not in protocol and 'icmpv6' not in protocol:
if protocol is None \
or ('icmpv4' not in protocol and 'icmpv6' not in protocol):
cmd.append('localport={0}'.format(localport))
ret = __salt__['cmd.run'](cmd, python_shell=False)
if isinstance(ret, six.string_types):
return ret.strip() == 'Ok.'
else:
log.error('firewall.add_rule failed: {0}'.format(ret))
return False
ret = __salt__['cmd.run_all'](cmd, python_shell=False, ignore_retcode=True)
if ret['retcode'] != 0:
raise CommandExecutionError(ret['stdout'])
return True
def delete_rule(name, localport, protocol='tcp', dir='in', remoteip='any'):
def delete_rule(name,
localport=None,
protocol=None,
dir=None,
remoteip=None):
'''
.. versionadded:: 2015.8.0
Delete an existing firewall rule
Delete an existing firewall rule identified by name and optionally by ports,
protocols, direction, and remote IP.
Args:
name (str): The name of the rule to delete. If the name ``all`` is used
you must specify additional parameters.
localport (Optional[str]): The port of the rule. Must specify a
protocol.
protocol (Optional[str]): The protocol of the rule.
dir (Optional[str]): The direction of the rule.
remoteip (Optional[str]): The remote IP of the rule.
Returns:
bool: True if successful
Raises:
CommandExecutionError: If the command fails
CLI Example:
.. code-block:: bash
# Delete incoming tcp port 8080 in the rule named 'test'
salt '*' firewall.delete_rule 'test' '8080' 'tcp' 'in'
# Delete the incoming tcp port 8000 from 192.168.0.1 in the rule named
# 'test_remote_ip`
salt '*' firewall.delete_rule 'test_remote_ip' '8000' 'tcp' 'in' '192.168.0.1'
# Delete all rules for local port 80:
salt '*' firewall.delete_rule all 80 tcp
# Delete a rule called 'allow80':
salt '*' firewall.delete_rule allow80
'''
cmd = ['netsh', 'advfirewall', 'firewall', 'delete', 'rule',
'name={0}'.format(name),
'protocol={0}'.format(protocol),
'dir={0}'.format(dir),
'remoteip={0}'.format(remoteip)]
'name={0}'.format(name)]
if protocol:
cmd.append('protocol={0}'.format(protocol))
if dir:
cmd.append('dir={0}'.format(dir))
if remoteip:
cmd.append('remoteip={0}'.format(remoteip))
if 'icmpv4' not in protocol and 'icmpv6' not in protocol:
cmd.append('localport={0}'.format(localport))
if protocol is None \
or ('icmpv4' not in protocol and 'icmpv6' not in protocol):
if localport:
cmd.append('localport={0}'.format(localport))
ret = __salt__['cmd.run'](cmd, python_shell=False)
if isinstance(ret, six.string_types):
return ret.endswith('Ok.')
else:
log.error('firewall.delete_rule failed: {0}'.format(ret))
ret = __salt__['cmd.run_all'](cmd, python_shell=False, ignore_retcode=True)
if ret['retcode'] != 0:
raise CommandExecutionError(ret['stdout'])
return True
def rule_exists(name):
'''
.. versionadded:: 2016.11.6
Checks if a firewall rule exists in the firewall policy
Args:
name (str): The name of the rule
Returns:
bool: True if exists, otherwise False
CLI Example:
.. code-block:: bash
# Is there a rule named RemoteDesktop
salt '*' firewall.rule_exists RemoteDesktop
'''
try:
get_rule(name)
return True
except CommandExecutionError:
return False

View File

@ -12,10 +12,6 @@ or for problem solving if your minion is having problems.
# Import Python Libs
from __future__ import absolute_import
import os
import ctypes
import sys
import time
import datetime
import logging
import subprocess
@ -23,14 +19,15 @@ log = logging.getLogger(__name__)
# Import Salt Libs
import salt.utils
import salt.ext.six as six
import salt.utils.event
from salt.utils.network import host_to_ips as _host_to_ips
from salt.utils import namespaced_function as _namespaced_function
# These imports needed for namespaced functions
# pylint: disable=W0611
from salt.modules.status import ping_master, time_
import copy
# pylint: enable=W0611
from salt.utils import namespaced_function as _namespaced_function
# Import 3rd Party Libs
try:
@ -43,6 +40,11 @@ try:
except ImportError:
HAS_WMI = False
HAS_PSUTIL = False
if salt.utils.is_windows():
import psutil
HAS_PSUTIL = True
__opts__ = {}
__virtualname__ = 'status'
@ -57,6 +59,9 @@ def __virtual__():
if not HAS_WMI:
return False, 'win_status.py: Requires WMI and WinAPI'
if not HAS_PSUTIL:
return False, 'win_status.py: Requires psutil'
# Namespace modules from `status.py`
global ping_master, time_
ping_master = _namespaced_function(ping_master, globals())
@ -81,10 +86,7 @@ def cpuload():
salt '*' status.cpuload
'''
# Pull in the information from WMIC
cmd = ['wmic', 'cpu', 'get', 'loadpercentage', '/value']
return int(__salt__['cmd.run'](cmd).split('=')[1])
return psutil.cpu_percent()
def diskusage(human_readable=False, path=None):
@ -105,29 +107,22 @@ def diskusage(human_readable=False, path=None):
if not path:
path = 'c:/'
# Credit for the source and ideas for this function:
# http://code.activestate.com/recipes/577972-disk-usage/?in=user-4178764
_, total, free = \
ctypes.c_ulonglong(), ctypes.c_ulonglong(), ctypes.c_longlong()
if sys.version_info >= (3, ) or isinstance(path, six.text_type):
fun = ctypes.windll.kernel32.GetDiskFreeSpaceExW
else:
fun = ctypes.windll.kernel32.GetDiskFreeSpaceExA
ret = fun(path, ctypes.byref(_), ctypes.byref(total), ctypes.byref(free))
if ret == 0:
raise ctypes.WinError()
used = total.value - free.value
disk_stats = psutil.disk_usage(path)
total_val = total.value
used_val = used
free_val = free.value
total_val = disk_stats.total
used_val = disk_stats.used
free_val = disk_stats.free
percent = disk_stats.percent
if human_readable:
total_val = _byte_calc(total_val)
used_val = _byte_calc(used_val)
free_val = _byte_calc(free_val)
return {'total': total_val, 'used': used_val, 'free': free_val}
return {'total': total_val,
'used': used_val,
'free': free_val,
'percent': percent}
def procs(count=False):
@ -178,16 +173,17 @@ def saltmem(human_readable=False):
salt '*' status.saltmem
salt '*' status.saltmem human_readable=True
'''
with salt.utils.winapi.Com():
wmi_obj = wmi.WMI()
result = wmi_obj.query(
'SELECT WorkingSet FROM Win32_PerfRawData_PerfProc_Process '
'WHERE IDProcess={0}'.format(os.getpid())
)
mem = int(result[0].wmi_property('WorkingSet').value)
if human_readable:
return _byte_calc(mem)
return mem
# psutil.Process defaults to current process (`os.getpid()`)
p = psutil.Process()
# Use oneshot to get a snapshot
with p.oneshot():
mem = p.memory_info().rss
if human_readable:
return _byte_calc(mem)
return mem
def uptime(human_readable=False):
@ -206,15 +202,8 @@ def uptime(human_readable=False):
salt '*' status.uptime
salt '*' status.uptime human_readable=True
'''
# Open up a subprocess to get information from WMIC
cmd = ['wmic', 'os', 'get', 'lastbootuptime', '/value']
startup_time = __salt__['cmd.run'](cmd).split('=')[1][:14]
# Convert to time struct
startup_time = time.strptime(startup_time, '%Y%m%d%H%M%S')
# Convert to datetime object
startup_time = datetime.datetime(*startup_time[:6])
# Get startup time
startup_time = datetime.datetime.fromtimestamp(psutil.boot_time())
# Subtract startup time from current time to get the uptime of the system
uptime = datetime.datetime.now() - startup_time

View File

@ -200,16 +200,54 @@ def returner(ret):
mdb.saltReturns.insert(sdata.copy())
def _safe_copy(dat):
''' mongodb doesn't allow '.' in keys, but does allow unicode equivs.
Apparently the docs suggest using escaped unicode full-width
encodings. *sigh*
\\ --> \\\\
$ --> \\\\u0024
. --> \\\\u002e
Personally, I prefer URL encodings,
\\ --> %5c
$ --> %24
. --> %2e
Which means also escaping '%':
% -> %25
'''
if isinstance(dat, dict):
ret = {}
for k in dat:
r = k.replace('%', '%25').replace('\\', '%5c').replace('$', '%24').replace('.', '%2e')
if r != k:
log.debug('converting dict key from {0} to {1} for mongodb'.format(k, r))
ret[r] = _safe_copy(dat[k])
return ret
if isinstance(dat, (list, tuple)):
return [_safe_copy(i) for i in dat]
return dat
def save_load(jid, load, minions=None):
'''
Save the load for a given job id
'''
conn, mdb = _get_conn(ret=None)
to_save = _safe_copy(load)
if float(version) > 2.3:
#using .copy() to ensure original data for load is unchanged
mdb.jobs.insert_one(load.copy())
mdb.jobs.insert_one(to_save)
else:
mdb.jobs.insert(load.copy())
mdb.jobs.insert(to_save)
def save_minions(jid, minions, syndic_id=None): # pylint: disable=unused-argument

View File

@ -128,7 +128,7 @@ def _get_serv(ret=None):
def _get_ttl():
return __opts__['keep_jobs'] * 3600
return __opts__.get('keep_jobs', 24) * 3600
def returner(ret):

View File

@ -62,7 +62,7 @@ def _update_config(template_name,
template_mode='755',
saltenv=None,
template_engine='jinja',
skip_verify=True,
skip_verify=False,
defaults=None,
test=False,
commit=True,

View File

@ -2,6 +2,10 @@
'''
State for configuring Windows Firewall
'''
from __future__ import absolute_import
# Import Salt libs
from salt.exceptions import CommandExecutionError, SaltInvocationError
def __virtual__():
@ -14,59 +18,143 @@ def __virtual__():
def disabled(name='allprofiles'):
'''
Disable all the firewall profiles (Windows only)
Args:
profile (Optional[str]): The name of the profile to disable. Default is
``allprofiles``. Valid options are:
- allprofiles
- domainprofile
- privateprofile
- publicprofile
Example:
.. code-block:: yaml
# To disable the domain profile
disable_domain:
win_firewall.disabled:
- name: domainprofile
# To disable all profiles
disable_all:
win_firewall.disabled:
- name: allprofiles
'''
ret = {'name': name,
'result': True,
'changes': {},
'comment': ''}
# Determine what to do
action = False
check_name = None
if name != 'allprofiles':
check_name = True
profile_map = {'domainprofile': 'Domain',
'privateprofile': 'Private',
'publicprofile': 'Public',
'allprofiles': 'All'}
# Make sure the profile name is valid
if name not in profile_map:
raise SaltInvocationError('Invalid profile name: {0}'.format(name))
current_config = __salt__['firewall.get_config']()
if check_name and name not in current_config:
if name != 'allprofiles' and profile_map[name] not in current_config:
ret['result'] = False
ret['comment'] = 'Profile {0} does not exist in firewall.get_config'.format(name)
ret['comment'] = 'Profile {0} does not exist in firewall.get_config' \
''.format(name)
return ret
for key in current_config:
if current_config[key]:
if check_name and key != name:
continue
action = True
ret['changes'] = {'fw': 'disabled'}
break
if name == 'allprofiles' or key == profile_map[name]:
ret['changes'][key] = 'disabled'
if __opts__['test']:
ret['result'] = not action or None
ret['result'] = not ret['changes'] or None
ret['comment'] = ret['changes']
ret['changes'] = {}
return ret
# Disable it
if action:
ret['result'] = __salt__['firewall.disable'](name)
if not ret['result']:
ret['comment'] = 'Could not disable the FW'
if check_name:
msg = 'Firewall profile {0} could not be disabled'.format(name)
else:
msg = 'Could not disable the FW'
ret['comment'] = msg
if ret['changes']:
try:
ret['result'] = __salt__['firewall.disable'](name)
except CommandExecutionError:
ret['comment'] = 'Firewall Profile {0} could not be disabled' \
''.format(profile_map[name])
else:
if check_name:
msg = 'Firewall profile {0} is disabled'.format(name)
else:
if name == 'allprofiles':
msg = 'All the firewall profiles are disabled'
else:
msg = 'Firewall profile {0} is disabled'.format(name)
ret['comment'] = msg
return ret
def add_rule(name, localport, protocol="tcp", action="allow", dir="in"):
def add_rule(name,
localport,
protocol='tcp',
action='allow',
dir='in',
remoteip='any'):
'''
Add a new firewall rule (Windows only)
Add a new inbound or outbound rule to the firewall policy
Args:
name (str): The name of the rule. Must be unique and cannot be "all".
Required.
localport (int): The port the rule applies to. Must be a number between
0 and 65535. Can be a range. Can specify multiple ports separated by
commas. Required.
protocol (Optional[str]): The protocol. Can be any of the following:
- A number between 0 and 255
- icmpv4
- icmpv6
- tcp
- udp
- any
action (Optional[str]): The action the rule performs. Can be any of the
following:
- allow
- block
- bypass
dir (Optional[str]): The direction. Can be ``in`` or ``out``.
remoteip (Optional [str]): The remote IP. Can be any of the following:
- any
- localsubnet
- dns
- dhcp
- wins
- defaultgateway
- Any valid IPv4 address (192.168.0.12)
- Any valid IPv6 address (2002:9b3b:1a31:4:208:74ff:fe39:6c43)
- Any valid subnet (192.168.1.0/24)
- Any valid range of IP addresses (192.168.0.1-192.168.0.12)
- A list of valid IP addresses
Can be combinations of the above separated by commas.
.. versionadded:: 2016.11.6
Example:
.. code-block:: yaml
open_smb_port:
win_firewall.add_rule:
- name: SMB (445)
- localport: 445
- protocol: tcp
- action: allow
'''
ret = {'name': name,
'result': True,
@ -74,23 +162,24 @@ def add_rule(name, localport, protocol="tcp", action="allow", dir="in"):
'comment': ''}
# Check if rule exists
commit = False
current_rules = __salt__['firewall.get_rule'](name)
if not current_rules:
commit = True
if not __salt__['firewall.rule_exists'](name):
ret['changes'] = {'new rule': name}
else:
ret['comment'] = 'A rule with that name already exists'
return ret
if __opts__['test']:
ret['result'] = not commit or None
ret['result'] = not ret['changes'] or None
ret['comment'] = ret['changes']
ret['changes'] = {}
return ret
# Add rule
if commit:
ret['result'] = __salt__['firewall.add_rule'](name, localport, protocol, action, dir)
if not ret['result']:
ret['comment'] = 'Could not add rule'
else:
ret['comment'] = 'A rule with that name already exists'
try:
__salt__['firewall.add_rule'](
name, localport, protocol, action, dir, remoteip)
except CommandExecutionError:
ret['comment'] = 'Could not add rule'
return ret
@ -98,50 +187,74 @@ def add_rule(name, localport, protocol="tcp", action="allow", dir="in"):
def enabled(name='allprofiles'):
'''
Enable all the firewall profiles (Windows only)
Args:
profile (Optional[str]): The name of the profile to enable. Default is
``allprofiles``. Valid options are:
- allprofiles
- domainprofile
- privateprofile
- publicprofile
Example:
.. code-block:: yaml
# To enable the domain profile
enable_domain:
win_firewall.enabled:
- name: domainprofile
# To enable all profiles
enable_all:
win_firewall.enabled:
- name: allprofiles
'''
ret = {'name': name,
'result': True,
'changes': {},
'comment': ''}
# Determine what to do
action = False
check_name = None
if name != 'allprofiles':
check_name = True
profile_map = {'domainprofile': 'Domain',
'privateprofile': 'Private',
'publicprofile': 'Public',
'allprofiles': 'All'}
# Make sure the profile name is valid
if name not in profile_map:
raise SaltInvocationError('Invalid profile name: {0}'.format(name))
current_config = __salt__['firewall.get_config']()
if check_name and name not in current_config:
if name != 'allprofiles' and profile_map[name] not in current_config:
ret['result'] = False
ret['comment'] = 'Profile {0} does not exist in firewall.get_config'.format(name)
ret['comment'] = 'Profile {0} does not exist in firewall.get_config' \
''.format(name)
return ret
for key in current_config:
if not current_config[key]:
if check_name and key != name:
continue
action = True
ret['changes'] = {'fw': 'enabled'}
break
if name == 'allprofiles' or key == profile_map[name]:
ret['changes'][key] = 'enabled'
if __opts__['test']:
ret['result'] = not action or None
ret['result'] = not ret['changes'] or None
ret['comment'] = ret['changes']
ret['changes'] = {}
return ret
# Disable it
if action:
ret['result'] = __salt__['firewall.enable'](name)
if not ret['result']:
if check_name:
msg = 'Firewall profile {0} could not be enabled'.format(name)
else:
msg = 'Could not enable the FW'
ret['comment'] = msg
# Enable it
if ret['changes']:
try:
ret['result'] = __salt__['firewall.enable'](name)
except CommandExecutionError:
ret['comment'] = 'Firewall Profile {0} could not be enabled' \
''.format(profile_map[name])
else:
if check_name:
msg = 'Firewall profile {0} is enabled'.format(name)
else:
if name == 'allprofiles':
msg = 'All the firewall profiles are enabled'
else:
msg = 'Firewall profile {0} is enabled'.format(name)
ret['comment'] = msg
return ret

View File

@ -15,7 +15,7 @@ import salt.utils
# Import 3rd-party libs
import salt.ext.six as six
from salt.ext.six import BytesIO, StringIO
from salt.ext.six import BytesIO
class GzipFile(gzip.GzipFile):
@ -90,7 +90,7 @@ def compress_file(fh_, compresslevel=9, chunk_size=1048576):
raise ValueError('chunk_size must be an integer')
try:
while bytes_read == chunk_size:
buf = StringIO()
buf = BytesIO()
with open_fileobj(buf, 'wb', compresslevel) as ogz:
try:
bytes_read = ogz.write(fh_.read(chunk_size))

View File

@ -622,16 +622,20 @@ class Schema(six.with_metaclass(SchemaMeta, object)):
if properties:
serialized['properties'] = properties
# Update the serialized object with any items to include after properties
# Update the serialized object with any items to include after properties.
# Do not overwrite properties already existing in the serialized dict.
if cls.after_items_update:
after_items_update = {}
for entry in cls.after_items_update:
name, data = next(six.iteritems(entry))
if name in after_items_update:
after_items_update[name].extend(data)
else:
after_items_update[name] = data
serialized.update(after_items_update)
for name, data in six.iteritems(entry):
if name in after_items_update:
if isinstance(after_items_update[name], list):
after_items_update[name].extend(data)
else:
after_items_update[name] = data
if after_items_update:
after_items_update.update(serialized)
serialized = after_items_update
if required:
# Only include required if not empty

View File

@ -47,7 +47,10 @@ def setup_handlers():
log.warning('Failed to connect to log server')
return
finally:
sock.shutdown(socket.SHUT_RDWR)
try:
sock.shutdown(socket.SHUT_RDWR)
except OSError:
pass
sock.close()
queue = Queue()

220
tests/unit/cache/test_cache.py vendored Normal file
View File

@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
'''
unit tests for the localfs cache
'''
# Import Python libs
from __future__ import absolute_import
# Import Salt Testing libs
# import integration
from tests.support.unit import skipIf, TestCase
from tests.support.mock import (
NO_MOCK,
NO_MOCK_REASON,
patch,
)
# Import Salt libs
import salt.payload
import salt.utils
import salt.cache
class CacheFunctionsTest(TestCase):
'''
Validate the cache package functions.
'''
def setUp(self):
self.opts = {'cache': 'localfs',
'memcache_expire_seconds': 0,
'memcache_max_items': 0,
'memcache_full_cleanup': False,
'memcache_debug': False}
def test_factory_cache(self):
ret = salt.cache.factory(self.opts)
self.assertIsInstance(ret, salt.cache.Cache)
def test_factory_memcache(self):
self.opts['memcache_expire_seconds'] = 10
ret = salt.cache.factory(self.opts)
self.assertIsInstance(ret, salt.cache.MemCache)
@skipIf(NO_MOCK, NO_MOCK_REASON)
class MemCacheTest(TestCase):
'''
Validate Cache class methods
'''
@patch('salt.payload.Serial')
def setUp(self, serial_mock): # pylint: disable=W0221
salt.cache.MemCache.data = {}
self.opts = {'cache': 'fake_driver',
'memcache_expire_seconds': 10,
'memcache_max_items': 3,
'memcache_full_cleanup': False,
'memcache_debug': False}
self.cache = salt.cache.factory(self.opts)
@patch('salt.cache.Cache.fetch', return_value='fake_data')
@patch('salt.loader.cache', return_value={})
def test_fetch(self, loader_mock, cache_fetch_mock):
# Fetch value, it will be kept in cache.
with patch('time.time', return_value=0):
ret = self.cache.fetch('bank', 'key')
self.assertEqual(ret, 'fake_data')
self.assertDictEqual(salt.cache.MemCache.data, {
'fake_driver': {
('bank', 'key'): [0, 'fake_data'],
}})
cache_fetch_mock.assert_called_once_with('bank', 'key')
cache_fetch_mock.reset_mock()
# Fetch again, cached value is used, time updated.
with patch('time.time', return_value=1):
ret = self.cache.fetch('bank', 'key')
self.assertEqual(ret, 'fake_data')
self.assertDictEqual(salt.cache.MemCache.data, {
'fake_driver': {
('bank', 'key'): [1, 'fake_data'],
}})
cache_fetch_mock.assert_not_called()
# Fetch after expire
with patch('time.time', return_value=12):
ret = self.cache.fetch('bank', 'key')
self.assertEqual(ret, 'fake_data')
self.assertDictEqual(salt.cache.MemCache.data, {
'fake_driver': {
('bank', 'key'): [12, 'fake_data'],
}})
cache_fetch_mock.assert_called_once_with('bank', 'key')
cache_fetch_mock.reset_mock()
@patch('salt.cache.Cache.store')
@patch('salt.loader.cache', return_value={})
def test_store(self, loader_mock, cache_store_mock):
# Fetch value, it will be kept in cache.
with patch('time.time', return_value=0):
self.cache.store('bank', 'key', 'fake_data')
self.assertDictEqual(salt.cache.MemCache.data, {
'fake_driver': {
('bank', 'key'): [0, 'fake_data'],
}})
cache_store_mock.assert_called_once_with('bank', 'key', 'fake_data')
cache_store_mock.reset_mock()
# Store another value.
with patch('time.time', return_value=1):
self.cache.store('bank', 'key2', 'fake_data2')
self.assertDictEqual(salt.cache.MemCache.data, {
'fake_driver': {
('bank', 'key'): [0, 'fake_data'],
('bank', 'key2'): [1, 'fake_data2'],
}})
cache_store_mock.assert_called_once_with('bank', 'key2', 'fake_data2')
@patch('salt.cache.Cache.store')
@patch('salt.cache.Cache.flush')
@patch('salt.loader.cache', return_value={})
def test_flush(self, loader_mock, cache_flush_mock, cache_store_mock):
# Flush non-existing bank
self.cache.flush('bank')
self.assertDictEqual(salt.cache.MemCache.data, {'fake_driver': {}})
cache_flush_mock.assert_called_once_with('bank', None)
cache_flush_mock.reset_mock()
# Flush non-existing key
self.cache.flush('bank', 'key')
self.assertDictEqual(salt.cache.MemCache.data, {'fake_driver': {}})
cache_flush_mock.assert_called_once_with('bank', 'key')
cache_flush_mock.reset_mock()
# Flush existing key
with patch('time.time', return_value=0):
self.cache.store('bank', 'key', 'fake_data')
self.assertEqual(salt.cache.MemCache.data['fake_driver'][('bank', 'key')],
[0, 'fake_data'])
self.assertDictEqual(salt.cache.MemCache.data, {
'fake_driver': {
('bank', 'key'): [0, 'fake_data'],
}})
self.cache.flush('bank', 'key')
self.assertDictEqual(salt.cache.MemCache.data, {'fake_driver': {}})
cache_flush_mock.assert_called_once_with('bank', 'key')
cache_flush_mock.reset_mock()
@patch('salt.cache.Cache.store')
@patch('salt.loader.cache', return_value={})
def test_max_items(self, loader_mock, cache_store_mock):
# Put MAX=3 values
with patch('time.time', return_value=0):
self.cache.store('bank1', 'key1', 'fake_data11')
with patch('time.time', return_value=1):
self.cache.store('bank1', 'key2', 'fake_data12')
with patch('time.time', return_value=2):
self.cache.store('bank2', 'key1', 'fake_data21')
self.assertDictEqual(salt.cache.MemCache.data['fake_driver'], {
('bank1', 'key1'): [0, 'fake_data11'],
('bank1', 'key2'): [1, 'fake_data12'],
('bank2', 'key1'): [2, 'fake_data21'],
})
# Put one more and check the oldest was removed
with patch('time.time', return_value=3):
self.cache.store('bank2', 'key2', 'fake_data22')
self.assertDictEqual(salt.cache.MemCache.data['fake_driver'], {
('bank1', 'key2'): [1, 'fake_data12'],
('bank2', 'key1'): [2, 'fake_data21'],
('bank2', 'key2'): [3, 'fake_data22'],
})
@patch('salt.cache.Cache.store')
@patch('salt.loader.cache', return_value={})
def test_full_cleanup(self, loader_mock, cache_store_mock):
# Enable full cleanup
self.cache.cleanup = True
# Put MAX=3 values
with patch('time.time', return_value=0):
self.cache.store('bank1', 'key1', 'fake_data11')
with patch('time.time', return_value=1):
self.cache.store('bank1', 'key2', 'fake_data12')
with patch('time.time', return_value=2):
self.cache.store('bank2', 'key1', 'fake_data21')
self.assertDictEqual(salt.cache.MemCache.data['fake_driver'], {
('bank1', 'key1'): [0, 'fake_data11'],
('bank1', 'key2'): [1, 'fake_data12'],
('bank2', 'key1'): [2, 'fake_data21'],
})
# Put one more and check all expired was removed
with patch('time.time', return_value=12):
self.cache.store('bank2', 'key2', 'fake_data22')
self.assertDictEqual(salt.cache.MemCache.data['fake_driver'], {
('bank2', 'key1'): [2, 'fake_data21'],
('bank2', 'key2'): [12, 'fake_data22'],
})
@patch('salt.cache.Cache.fetch', return_value='fake_data')
@patch('salt.loader.cache', return_value={})
def test_fetch_debug(self, loader_mock, cache_fetch_mock):
# Recreate cache with debug enabled
self.opts['memcache_debug'] = True
self.cache = salt.cache.factory(self.opts)
# Fetch 2 values (no cache hit)
with patch('time.time', return_value=0):
ret = self.cache.fetch('bank', 'key1')
with patch('time.time', return_value=1):
ret = self.cache.fetch('bank', 'key2')
# Fetch 3 times (cache hit)
with patch('time.time', return_value=2):
ret = self.cache.fetch('bank', 'key2')
with patch('time.time', return_value=3):
ret = self.cache.fetch('bank', 'key1')
with patch('time.time', return_value=4):
ret = self.cache.fetch('bank', 'key1')
# Fetch an expired value (no cache hit)
with patch('time.time', return_value=13):
ret = self.cache.fetch('bank', 'key2')
# Check debug data
self.assertEqual(self.cache.call, 6)
self.assertEqual(self.cache.hit, 3)

View File

@ -17,7 +17,6 @@ from salt.config.schemas.minion import MinionConfiguration
from salt.utils.versions import LooseVersion as _LooseVersion
# Import 3rd-party libs
import salt.ext.six as six
try:
import jsonschema
import jsonschema.exceptions
@ -28,8 +27,7 @@ except ImportError:
JSONSCHEMA_VERSION = _LooseVersion('0')
@skipIf(six.PY3, 'Tests disabled under Python 3')
class RoosterEntryConfigTest(TestCase):
class RosterEntryConfigTest(TestCase):
def test_config(self):
config = ssh_schemas.RosterEntryConfig()

View File

@ -189,12 +189,15 @@ class DjangomodCliCommandTestCase(TestCase, LoaderModuleMockMixin):
djangomod.createsuperuser(
'settings.py', 'testuser', 'user@example.com'
)
mock.assert_called_once_with(
'django-admin.py createsuperuser --settings=settings.py '
'--noinput --username=testuser --email=user@example.com',
python_shell=False,
env=None
)
mock.assert_called_once()
args, kwargs = mock.call_args
# cmdline arguments are extracted from a kwargs dict so order isn't guaranteed.
self.assertEqual(len(args), 1)
self.assertTrue(args[0].startswith('django-admin.py createsuperuser --'))
self.assertEqual(set(args[0].split()),
set('django-admin.py createsuperuser --settings=settings.py --noinput '
'--username=testuser --email=user@example.com'.split()))
self.assertDictEqual(kwargs, {'python_shell': False, 'env': None})
def no_test_loaddata(self):
mock = MagicMock()

View File

@ -96,8 +96,8 @@ class PillarModuleTestCase(TestCase, LoaderModuleMockMixin):
if default_type == data_type:
continue
self.assertEqual(
pillarmod.get(item, default=defaults[default_type], merge=True),
pillarmod.__pillar__[item]
pillarmod.get(data_type, default=defaults[default_type], merge=True),
pillarmod.__pillar__[data_type]
)
# Test recursive dict merging

View File

@ -89,9 +89,14 @@ class SSHAuthKeyTestCase(TestCase, LoaderModuleMockMixin):
'/w4yCE6gbODqnTWlg7+wC604ydGXA8VJiS5ap43JXiUFFAaQ=='
options = 'command="/usr/local/lib/ssh-helper"'
email = 'github.com'
empty_line = '\n'
comment_line = '# this is a comment \n'
# Write out the authorized key to a temporary file
temp_file = tempfile.NamedTemporaryFile(delete=False, mode='w+')
# Add comment
temp_file.write(comment_line)
# Add empty line for #41335
temp_file.write(empty_line)
temp_file.write('{0} {1} {2} {3}'.format(options, enc, key, email))
temp_file.close()
@ -131,3 +136,5 @@ class SSHAuthKeyTestCase(TestCase, LoaderModuleMockMixin):
self.assertIn(key, file_txt)
self.assertIn('{0} '.format(','.join(options)), file_txt)
self.assertIn(email, file_txt)
self.assertIn(empty_line, file_txt)
self.assertIn(comment_line, file_txt)

View File

@ -1,203 +0,0 @@
# -*- coding: utf-8 -*-
'''
:codeauthor: :email:`Jayesh Kariya <jayeshk@saltstack.com>`
'''
# Import Python Libs
from __future__ import absolute_import
# Import Salt Testing Libs
from tests.support.mixins import LoaderModuleMockMixin
from tests.support.unit import TestCase
from tests.support.mock import (
MagicMock,
patch,
call
)
# Import Salt Libs
import salt.modules.win_firewall as win_firewall
class WinFirewallTestCase(TestCase, LoaderModuleMockMixin):
'''
Test cases for salt.modules.win_firewall
'''
def setup_loader_modules(self):
return {win_firewall: {}}
# 'get_config' function tests: 1
def test_get_config(self):
'''
Test if it get the status of all the firewall profiles
'''
mock_cmd = MagicMock(return_value='')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertDictEqual(win_firewall.get_config(), {})
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'show', 'allprofiles'], python_shell=False)
# 'disable' function tests: 1
def test_disable(self):
'''
Test if it disable firewall profile :(default: allprofiles)
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.disable())
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'set', 'allprofiles', 'state', 'off'],
python_shell=False)
# 'enable' function tests: 1
def test_enable(self):
'''
Test if it enable firewall profile :(default: allprofiles)
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.enable())
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'set', 'allprofiles', 'state', 'on'],
python_shell=False)
# 'get_rule' function tests: 1
def test_get_rule(self):
'''
Test if it get firewall rule(s) info
'''
val = 'No rules match the specified criteria.'
mock_cmd = MagicMock(side_effect=['salt', val])
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertDictEqual(win_firewall.get_rule(), {'all': 'salt'})
self.assertFalse(win_firewall.get_rule())
calls = [
call(['netsh', 'advfirewall', 'firewall', 'show', 'rule', 'name=all'], python_shell=False),
call(['netsh', 'advfirewall', 'firewall', 'show', 'rule', 'name=all'], python_shell=False)
]
mock_cmd.assert_has_calls(calls)
# 'add_rule' function tests: 1
def test_add_rule(self):
'''
Test if it add a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.add_rule("test", "8080"))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall',
'firewall', 'add', 'rule',
'name=test', 'protocol=tcp',
'dir=in', 'action=allow',
'remoteip=any',
'localport=8080'],
python_shell=False)
def test_add_rule_icmp4(self):
'''
Test if it add a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.add_rule("test", "1", protocol='icmpv4'))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'firewall', 'add', 'rule',
'name=test',
'protocol=icmpv4',
'dir=in',
'action=allow',
'remoteip=any'],
python_shell=False)
def test_add_rule_icmp6(self):
'''
Test if it add a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.add_rule("test", "1", protocol='icmpv6'))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'firewall', 'add', 'rule',
'name=test',
'protocol=icmpv6',
'dir=in',
'action=allow',
'remoteip=any'],
python_shell=False)
def test_add_rule_icmp4_any(self):
'''
Test if it add a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.add_rule("test", "1", protocol='icmpv4:any,any'))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'firewall', 'add', 'rule',
'name=test',
'protocol=icmpv4:any,any',
'dir=in',
'action=allow',
'remoteip=any'],
python_shell=False)
# 'delete_rule' function tests: 1
def test_delete_rule(self):
'''
Test if it delete an existing firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.delete_rule("test", "8080", "tcp",
"in"))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall',
'firewall', 'delete', 'rule',
'name=test', 'protocol=tcp',
'dir=in',
'remoteip=any',
'localport=8080'],
python_shell=False)
def test_delete_rule_icmp4(self):
'''
Test if it deletes a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.delete_rule("test", "1", protocol='icmpv4'))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'firewall', 'delete', 'rule',
'name=test',
'protocol=icmpv4',
'dir=in',
'remoteip=any'],
python_shell=False)
def test_delete_rule_icmp6(self):
'''
Test if it deletes a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.delete_rule("test", "1", protocol='icmpv6'))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'firewall', 'delete', 'rule',
'name=test',
'protocol=icmpv6',
'dir=in',
'remoteip=any'],
python_shell=False)
def test_delete_rule_icmp4_any(self):
'''
Test if it deletes a new firewall rule
'''
mock_cmd = MagicMock(return_value='Ok.')
with patch.dict(win_firewall.__salt__, {'cmd.run': mock_cmd}):
self.assertTrue(win_firewall.delete_rule("test", "1", protocol='icmpv4:any,any'))
mock_cmd.assert_called_once_with(['netsh', 'advfirewall', 'firewall', 'delete', 'rule',
'name=test',
'protocol=icmpv4:any,any',
'dir=in',
'remoteip=any'],
python_shell=False)

View File

@ -1,67 +0,0 @@
# -*- coding: utf-8 -*-
'''
:codeauthor: :email:`Rahul Handay <rahulha@saltstack.com>`
'''
# Import Python Libs
from __future__ import absolute_import
# Import Salt Testing Libs
from tests.support.mixins import LoaderModuleMockMixin
from tests.support.unit import TestCase, skipIf
from tests.support.mock import (
MagicMock,
patch,
NO_MOCK,
NO_MOCK_REASON
)
# Import Salt Libs
import salt.states.win_firewall as win_firewall
@skipIf(NO_MOCK, NO_MOCK_REASON)
class WinFirewallTestCase(TestCase, LoaderModuleMockMixin):
'''
Validate the win_firewall state
'''
def setup_loader_modules(self):
return {win_firewall: {}}
def test_disabled(self):
'''
Test to disable all the firewall profiles (Windows only)
'''
ret = {'name': 'salt',
'changes': {},
'result': True,
'comment': ''}
mock = MagicMock(return_value={'salt': '', 'foo': ''})
with patch.dict(win_firewall.__salt__, {'firewall.get_config': mock}):
with patch.dict(win_firewall.__opts__, {'test': True}):
self.assertDictEqual(win_firewall.disabled('salt'), ret)
with patch.dict(win_firewall.__opts__, {'test': False}):
ret.update({'comment': 'Firewall profile salt is disabled',
'result': True})
self.assertDictEqual(win_firewall.disabled('salt'), ret)
def test_add_rule(self):
'''
Test to add a new firewall rule (Windows only)
'''
ret = {'name': 'salt',
'changes': {'new rule': 'salt'},
'result': None,
'comment': ''}
mock = MagicMock(return_value=False)
add_rule_mock = MagicMock(return_value=True)
with patch.dict(win_firewall.__salt__, {'firewall.get_rule': mock,
'firewall.add_rule': add_rule_mock}):
with patch.dict(win_firewall.__opts__, {'test': True}):
self.assertDictEqual(win_firewall.add_rule('salt', 'stack'), ret)
with patch.dict(win_firewall.__opts__, {'test': False}):
with patch.dict(win_firewall.__opts__, {'test': False}):
ret.update({'result': True})
result = win_firewall.add_rule('salt', 'stack')
self.assertDictEqual(result, ret)