Merge pull request #46023 from bloomberg/parallel-orch

add parallel support for orchestrations
This commit is contained in:
Nicole Thomas 2018-04-10 15:26:03 -04:00 committed by GitHub
commit 8adaf7f526
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 32 deletions

View File

@ -1429,6 +1429,9 @@ def runner(name, arg=None, kwarg=None, full_return=False, saltenv='base', jid=No
if 'saltenv' in aspec.args:
kwarg['saltenv'] = saltenv
if name in ['state.orchestrate', 'state.orch', 'state.sls']:
kwarg['orchestration_jid'] = jid
if jid:
salt.utils.event.fire_args(
__opts__,

View File

@ -229,7 +229,7 @@ class Runner(RunnerClient):
async_pub = self._gen_async_pub()
self.jid = async_pub['jid']
if low['fun'] in ('state.orchestrate', 'state.orch'):
if low['fun'] in ['state.orchestrate', 'state.orch', 'state.sls']:
low['kwarg']['orchestration_jid'] = async_pub['jid']
# Run the runner!

View File

@ -85,6 +85,7 @@ def orchestrate(mods,
saltenv=saltenv,
pillarenv=pillarenv,
pillar_enc=pillar_enc,
__pub_jid=orchestration_jid,
orchestration_jid=orchestration_jid)
ret = {'data': {minion.opts['id']: running}, 'outputter': 'highstate'}
res = salt.utils.check_state_result(ret['data'])

View File

@ -40,6 +40,7 @@ import salt.utils.process
import salt.utils.files
import salt.syspaths as syspaths
from salt.utils import immutabletypes
from salt.serializers.msgpack import serialize as msgpack_serialize, deserialize as msgpack_deserialize
from salt.template import compile_template, compile_template_str
from salt.exceptions import (
SaltException,
@ -56,7 +57,6 @@ import salt.utils.yamlloader as yamlloader
import salt.ext.six as six
from salt.ext.six.moves import map, range, reload_module
# pylint: enable=import-error,no-name-in-module,redefined-builtin
import msgpack
log = logging.getLogger(__name__)
@ -174,7 +174,7 @@ def _calculate_fake_duration():
start_time = local_start_time.time().isoformat()
delta = (utc_finish_time - utc_start_time)
# duration in milliseconds.microseconds
duration = (delta.seconds * 1000000 + delta.microseconds)/1000.0
duration = (delta.seconds * 1000000 + delta.microseconds) / 1000.0
return start_time, duration
@ -1693,27 +1693,20 @@ class State(object):
errors.extend(req_in_errors)
return req_in_high, errors
def _call_parallel_target(self, cdata, low):
def _call_parallel_target(self, name, cdata, low):
'''
The target function to call that will create the parallel thread/process
'''
# we need to re-record start/end duration here because it is impossible to
# correctly calculate further down the chain
utc_start_time = datetime.datetime.utcnow()
tag = _gen_tag(low)
try:
ret = self.states[cdata['full']](*cdata['args'],
**cdata['kwargs'])
except Exception:
trb = traceback.format_exc()
# There are a number of possibilities to not have the cdata
# populated with what we might have expected, so just be smart
# enough to not raise another KeyError as the name is easily
# guessable and fallback in all cases to present the real
# exception to the user
if len(cdata['args']) > 0:
name = cdata['args'][0]
elif 'name' in cdata['kwargs']:
name = cdata['kwargs']['name']
else:
name = low.get('name', low.get('__id__'))
ret = {
'result': False,
'name': name,
@ -1721,6 +1714,13 @@ class State(object):
'comment': 'An exception occurred in this state: {0}'.format(
trb)
}
utc_finish_time = datetime.datetime.utcnow()
delta = (utc_finish_time - utc_start_time)
# duration in milliseconds.microseconds
duration = (delta.seconds * 1000000 + delta.microseconds) / 1000.0
ret['duration'] = duration
troot = os.path.join(self.opts['cachedir'], self.jid)
tfile = os.path.join(troot, _clean_tag(tag))
if not os.path.isdir(troot):
@ -1731,17 +1731,26 @@ class State(object):
# and the attempt, we are safe to pass
pass
with salt.utils.fopen(tfile, 'wb+') as fp_:
fp_.write(msgpack.dumps(ret))
fp_.write(msgpack_serialize(ret))
def call_parallel(self, cdata, low):
'''
Call the state defined in the given cdata in parallel
'''
# There are a number of possibilities to not have the cdata
# populated with what we might have expected, so just be smart
# enough to not raise another KeyError as the name is easily
# guessable and fallback in all cases to present the real
# exception to the user
name = (cdata.get('args') or [None])[0] or cdata['kwargs'].get('name')
if not name:
name = low.get('name', low.get('__id__'))
proc = salt.utils.process.MultiprocessingProcess(
target=self._call_parallel_target,
args=(cdata, low))
args=(name, cdata, low))
proc.start()
ret = {'name': cdata['args'][0],
ret = {'name': name,
'result': None,
'changes': {},
'comment': 'Started in a separate process',
@ -1879,12 +1888,10 @@ class State(object):
# enough to not raise another KeyError as the name is easily
# guessable and fallback in all cases to present the real
# exception to the user
if len(cdata['args']) > 0:
name = cdata['args'][0]
elif 'name' in cdata['kwargs']:
name = cdata['kwargs']['name']
else:
name = (cdata.get('args') or [None])[0] or cdata['kwargs'].get('name')
if not name:
name = low.get('name', low.get('__id__'))
ret = {
'result': False,
'name': name,
@ -1923,7 +1930,7 @@ class State(object):
ret['start_time'] = local_start_time.time().isoformat()
delta = (utc_finish_time - utc_start_time)
# duration in milliseconds.microseconds
duration = (delta.seconds * 1000000 + delta.microseconds)/1000.0
duration = (delta.seconds * 1000000 + delta.microseconds) / 1000.0
ret['duration'] = duration
ret['__id__'] = low['__id__']
log.info(
@ -2048,7 +2055,7 @@ class State(object):
while True:
if self.reconcile_procs(running):
break
time.sleep(0.01)
time.sleep(0.0001)
ret = dict(list(disabled.items()) + list(running.items()))
return ret
@ -2082,7 +2089,7 @@ class State(object):
'changes': {}}
try:
with salt.utils.fopen(ret_cache, 'rb') as fp_:
ret = msgpack.loads(fp_.read())
ret = msgpack_deserialize(fp_.read())
except (OSError, IOError):
ret = {'result': False,
'comment': 'Parallel cache failure',
@ -2176,15 +2183,17 @@ class State(object):
run_dict = self.pre
else:
run_dict = running
while True:
if self.reconcile_procs(run_dict):
break
time.sleep(0.0001)
for chunk in chunks:
tag = _gen_tag(chunk)
if tag not in run_dict:
fun_stats.add('unmet')
continue
if run_dict[tag].get('proc'):
# Run in parallel, first wait for a touch and then recheck
time.sleep(0.01)
return self.check_requisite(low, running, chunks, pre)
if r_state == 'onfail':
if run_dict[tag]['result'] is True:
fun_stats.add('onfail') # At least one state is OK

View File

@ -11,7 +11,6 @@ import re
import shutil
import subprocess
import time
import urllib
# Import salt libs
import salt.utils
@ -20,6 +19,8 @@ from salt.exceptions import CommandExecutionError, FileLockError, MinionError
# Import 3rd-party libs
from salt.ext import six
from salt.ext.six.moves.urllib.parse import quote # pylint: disable=no-name-in-module
log = logging.getLogger(__name__)
@ -312,7 +313,7 @@ def safe_filename_leaf(file_basename):
:codeauthor: Damon Atkins <https://github.com/damon-atkins>
'''
def _replace(re_obj):
return urllib.quote(re_obj.group(0), safe=u'')
return quote(re_obj.group(0), safe=u'')
if not isinstance(file_basename, six.text_type):
# the following string is not prefixed with u
return re.sub('[\\\\:/*?"<>|]',

View File

@ -6,10 +6,12 @@ Tests for the state runner
# Import Python Libs
from __future__ import absolute_import
import errno
import logging
import os
import shutil
import signal
import tempfile
import time
import textwrap
import yaml
import threading
@ -24,6 +26,8 @@ from tests.support.paths import TMP
import salt.utils
import salt.utils.event
log = logging.getLogger(__name__)
class StateRunnerTest(ShellCase):
'''
@ -275,3 +279,85 @@ class OrchEventTest(ShellCase):
finally:
del listener
signal.alarm(0)
def test_parallel_orchestrations(self):
    '''
    Test to confirm that the ``parallel`` state flag works in orch.

    We run 20 ``test.sleep`` states of 10 seconds each (19 independent
    ones plus one that requires the first), and ensure the total wall
    time is roughly two sleep-lengths (~20s) rather than the ~200s a
    serial run would take.
    '''
    self.write_conf({
        'fileserver_backend': ['roots'],
        'file_roots': {
            'base': [self.base_env],
        },
    })

    orch_sls = os.path.join(self.base_env, 'test_par_orch.sls')
    with salt.utils.fopen(orch_sls, 'w') as fp_:
        fp_.write(textwrap.dedent('''
            {% for count in range(1, 20) %}
            sleep {{ count }}:
              module.run:
                - name: test.sleep
                - length: 10
                - parallel: True
            {% endfor %}

            sleep 21:
              module.run:
                - name: test.sleep
                - length: 10
                - parallel: True
                - require:
                  - module: sleep 1
            '''))

    listener = salt.utils.event.get_event(
        'master',
        sock_dir=self.master_opts['sock_dir'],
        transport=self.master_opts['transport'],
        opts=self.master_opts)

    start_time = time.time()
    jid = self.run_run_plus(
        'state.orchestrate',
        'test_par_orch',
        __reload_config=True).get('jid')

    if jid is None:
        raise Exception('jid missing from run_run_plus output')

    signal.signal(signal.SIGALRM, self.alarm_handler)
    signal.alarm(self.timeout)
    received = False
    try:
        while True:
            event = listener.get_event(full=True)
            if event is None:
                continue

            # If we receive the ret for this job before self.timeout,
            # the test is implicitly successful; a serial run of these
            # twenty 10-second sleeps would take at least 200 seconds.
            if event['tag'] == 'salt/run/{0}/ret'.format(jid):
                received = True
                # Don't wrap this in a try/except. We want to know if the
                # data structure is different from what we expect!
                ret = event['data']['return']['data']['master']
                for state in ret:
                    data = ret[state]
                    # We expect each individual duration to exceed its
                    # 10-second sleep (durations are in milliseconds).
                    self.assertTrue(data['duration'] > 10000)
                break

        # Confirm the total runtime is roughly two sleep-lengths
        # (~20s), leaving a generous buffer for slow test hosts.
        self.assertTrue((time.time() - start_time) < 40)
    finally:
        self.assertTrue(received)
        del listener
        signal.alarm(0)