mirror of
https://github.com/valitydev/salt.git
synced 2024-11-06 16:45:27 +00:00
Fix signal handling in subprocesses
This commit is contained in:
parent
a13cb3eae6
commit
1d83b10025
@ -669,22 +669,9 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
|
||||
return instance
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if (salt.utils.platform.is_windows() and
|
||||
not hasattr(self, '_is_child') and
|
||||
self.__setstate__.__code__ is
|
||||
MultiprocessingProcess.__setstate__.__code__):
|
||||
# On Windows, if a derived class hasn't defined __setstate__, that
|
||||
# means the 'MultiprocessingProcess' version will be used. For this
|
||||
# version, save a copy of the args and kwargs to use with its
|
||||
# __setstate__ and __getstate__.
|
||||
# We do this so that __init__ will be invoked on Windows in the
|
||||
# child process so that a register_after_fork() equivalent will
|
||||
# work on Windows. Note that this will only work if the derived
|
||||
# class uses the exact same args and kwargs as this class. Hence
|
||||
# this will also work for 'SignalHandlingMultiprocessingProcess'.
|
||||
# However, many derived classes take params that they don't pass
|
||||
# down (eg opts). Those classes need to override __setstate__ and
|
||||
# __getstate__ themselves.
|
||||
if salt.utils.platform.is_windows():
|
||||
# On Windows, subclasses should call super if they define
|
||||
# __setstate__ and/or __getstate__
|
||||
self._args_for_getstate = copy.copy(args)
|
||||
self._kwargs_for_getstate = copy.copy(kwargs)
|
||||
|
||||
@ -706,42 +693,21 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
|
||||
# 'log_queue' and 'log_queue_level' from kwargs.
|
||||
super(MultiprocessingProcess, self).__init__(*args, **kwargs)
|
||||
|
||||
if salt.utils.platform.is_windows():
|
||||
# On Windows, the multiprocessing.Process object is reinitialized
|
||||
# in the child process via the constructor. Due to this, methods
|
||||
# such as ident() and is_alive() won't work properly. So we use
|
||||
# our own creation '_is_child' for this purpose.
|
||||
if hasattr(self, '_is_child'):
|
||||
# On Windows, no need to call register_after_fork().
|
||||
# register_after_fork() would only work on Windows if called
|
||||
# from the child process anyway. Since we know this is the
|
||||
# child process, call __setup_process_logging() directly.
|
||||
self.__setup_process_logging()
|
||||
multiprocessing.util.Finalize(
|
||||
self,
|
||||
salt.log.setup.shutdown_multiprocessing_logging,
|
||||
exitpriority=16
|
||||
)
|
||||
else:
|
||||
multiprocessing.util.register_after_fork(
|
||||
self,
|
||||
MultiprocessingProcess.__setup_process_logging
|
||||
)
|
||||
multiprocessing.util.Finalize(
|
||||
self,
|
||||
salt.log.setup.shutdown_multiprocessing_logging,
|
||||
exitpriority=16
|
||||
)
|
||||
self._after_fork_methods = [
|
||||
(MultiprocessingProcess._setup_process_logging, [self], {}),
|
||||
]
|
||||
self._finalize_methods = [
|
||||
(salt.log.setup.shutdown_multiprocessing_logging, [], {})
|
||||
]
|
||||
|
||||
# __setstate__ and __getstate__ are only used on Windows.
|
||||
# We do this so that __init__ will be invoked on Windows in the child
|
||||
# process so that a register_after_fork() equivalent will work on Windows.
|
||||
def __setstate__(self, state):
|
||||
self._is_child = True
|
||||
args = state['args']
|
||||
kwargs = state['kwargs']
|
||||
# This will invoke __init__ of the most derived class.
|
||||
self.__init__(*args, **kwargs)
|
||||
self._after_fork_methods = self._after_fork_methods
|
||||
self._finalize_methods = self._finalize_methods
|
||||
|
||||
def __getstate__(self):
|
||||
args = self._args_for_getstate
|
||||
@ -755,12 +721,17 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
|
||||
del self._args_for_getstate
|
||||
del self._kwargs_for_getstate
|
||||
return {'args': args,
|
||||
'kwargs': kwargs}
|
||||
'kwargs': kwargs,
|
||||
'_after_fork_methods': self._after_fork_methods,
|
||||
'_finalize_methods': self._finalize_methods,
|
||||
}
|
||||
|
||||
def __setup_process_logging(self):
|
||||
def _setup_process_logging(self):
|
||||
salt.log.setup.setup_multiprocessing_logging(self.log_queue)
|
||||
|
||||
def _run(self):
|
||||
for method, args, kwargs in self._after_fork_methods:
|
||||
method(*args, **kwargs)
|
||||
try:
|
||||
return self._original_run()
|
||||
except SystemExit:
|
||||
@ -774,31 +745,28 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
|
||||
# sys.stderr and set the proper exitcode and we have already logged
|
||||
# it above.
|
||||
raise
|
||||
finally:
|
||||
for method, args, kwargs in self._finalize_methods:
|
||||
method(*args, **kwargs)
|
||||
|
||||
|
||||
class SignalHandlingMultiprocessingProcess(MultiprocessingProcess):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SignalHandlingMultiprocessingProcess, self).__init__(*args, **kwargs)
|
||||
if salt.utils.platform.is_windows():
|
||||
if hasattr(self, '_is_child'):
|
||||
# On Windows, no need to call register_after_fork().
|
||||
# register_after_fork() would only work on Windows if called
|
||||
# from the child process anyway. Since we know this is the
|
||||
# child process, call __setup_signals() directly.
|
||||
self.__setup_signals()
|
||||
else:
|
||||
multiprocessing.util.register_after_fork(
|
||||
self,
|
||||
SignalHandlingMultiprocessingProcess.__setup_signals
|
||||
)
|
||||
self._signal_handled = multiprocessing.Event()
|
||||
self._after_fork_methods.append(
|
||||
(SignalHandlingMultiprocessingProcess._setup_signals, [self], {})
|
||||
)
|
||||
|
||||
def __setup_signals(self):
|
||||
def signal_handled(self):
|
||||
return self._signal_handled.is_set()
|
||||
|
||||
def _setup_signals(self):
|
||||
signal.signal(signal.SIGINT, self._handle_signals)
|
||||
signal.signal(signal.SIGTERM, self._handle_signals)
|
||||
|
||||
def _handle_signals(self, signum, sigframe):
|
||||
signal.signal(signal.SIGTERM, signal.SIG_IGN)
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
self._signal_handled.set()
|
||||
msg = '{0} received a '.format(self.__class__.__name__)
|
||||
if signum == signal.SIGINT:
|
||||
msg += 'SIGINT'
|
||||
@ -808,7 +776,7 @@ class SignalHandlingMultiprocessingProcess(MultiprocessingProcess):
|
||||
log.debug(msg)
|
||||
if HAS_PSUTIL:
|
||||
try:
|
||||
process = psutil.Process(self.pid)
|
||||
process = psutil.Process(os.getpid())
|
||||
if hasattr(process, 'children'):
|
||||
for child in process.children(recursive=True):
|
||||
try:
|
||||
|
@ -236,35 +236,166 @@ class TestProcess(TestCase):
|
||||
# pylint: enable=assignment-from-none
|
||||
|
||||
|
||||
@skipIf(sys.platform.startswith('win'), 'pickling nested function errors on Windows')
|
||||
class TestSignalHandlingMultiprocessingProcess(TestCase):
|
||||
|
||||
@classmethod
|
||||
def Process(cls, pid):
|
||||
raise psutil.NoSuchProcess(pid)
|
||||
|
||||
@classmethod
|
||||
def target(cls):
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
||||
@classmethod
|
||||
def children(cls, *args, **kwargs):
|
||||
raise psutil.NoSuchProcess(1)
|
||||
|
||||
@skipIf(NO_MOCK, NO_MOCK_REASON)
|
||||
def test_process_does_not_exist(self):
|
||||
def Process(pid):
|
||||
raise psutil.NoSuchProcess(pid)
|
||||
|
||||
def target():
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
||||
try:
|
||||
with patch('psutil.Process', Process):
|
||||
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=target)
|
||||
with patch('psutil.Process', self.Process):
|
||||
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.target)
|
||||
proc.start()
|
||||
except psutil.NoSuchProcess:
|
||||
assert False, "psutil.NoSuchProcess raised"
|
||||
|
||||
@skipIf(NO_MOCK, NO_MOCK_REASON)
|
||||
def test_process_children_do_not_exist(self):
|
||||
def children(*args, **kwargs):
|
||||
raise psutil.NoSuchProcess(1)
|
||||
|
||||
def target():
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
||||
try:
|
||||
with patch('psutil.Process.children', children):
|
||||
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=target)
|
||||
with patch('psutil.Process.children', self.children):
|
||||
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.target)
|
||||
proc.start()
|
||||
except psutil.NoSuchProcess:
|
||||
assert False, "psutil.NoSuchProcess raised"
|
||||
|
||||
@staticmethod
|
||||
def run_forever_sub_target(evt):
|
||||
'Used by run_forever_target to create a sub-process'
|
||||
while not evt.is_set():
|
||||
time.sleep(1)
|
||||
|
||||
@staticmethod
|
||||
def run_forever_target(sub_target, evt):
|
||||
'A target that will run forever or until an event is set'
|
||||
p = multiprocessing.Process(target=sub_target, args=(evt,))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
@staticmethod
|
||||
def kill_target_sub_proc():
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
return
|
||||
pid = os.fork()
|
||||
if pid == 0:
|
||||
return
|
||||
time.sleep(.1)
|
||||
try:
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
@skipIf(sys.platform.startswith('win'), 'No os.fork on Windows')
|
||||
def test_signal_processing_regression_test(self):
|
||||
evt = multiprocessing.Event()
|
||||
sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(
|
||||
target=self.run_forever_target,
|
||||
args=(self.run_forever_sub_target, evt)
|
||||
)
|
||||
sh_proc.start()
|
||||
proc = multiprocessing.Process(target=self.kill_target_sub_proc)
|
||||
proc.start()
|
||||
proc.join()
|
||||
# When the bug exists, the kill_target_sub_proc signal will kill both
|
||||
# processes. sh_proc will be alive if the bug is fixed
|
||||
try:
|
||||
assert sh_proc.is_alive()
|
||||
finally:
|
||||
evt.set()
|
||||
sh_proc.join()
|
||||
|
||||
@staticmethod
|
||||
def no_op_target():
|
||||
pass
|
||||
|
||||
@skipIf(NO_MOCK, NO_MOCK_REASON)
|
||||
def test_signal_processing_test_after_fork_called(self):
|
||||
'Validate MultiprocessingProcess and sub classes call after fork methods'
|
||||
evt = multiprocessing.Event()
|
||||
sig_to_mock = 'salt.utils.process.SignalHandlingMultiprocessingProcess._setup_signals'
|
||||
log_to_mock = 'salt.utils.process.MultiprocessingProcess._setup_process_logging'
|
||||
with patch(sig_to_mock) as ma, patch(log_to_mock) as mb:
|
||||
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
|
||||
self.sh_proc._run()
|
||||
ma.assert_called()
|
||||
mb.assert_called()
|
||||
|
||||
@skipIf(NO_MOCK, NO_MOCK_REASON)
|
||||
def test_signal_processing_test_final_methods_called(self):
|
||||
'Validate MultiprocessingProcess and sub classes call finalize methods'
|
||||
evt = multiprocessing.Event()
|
||||
teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging'
|
||||
log_to_mock = 'salt.utils.process.MultiprocessingProcess._setup_process_logging'
|
||||
sig_to_mock = 'salt.utils.process.SignalHandlingMultiprocessingProcess._setup_signals'
|
||||
# Mock _setup_signals so we do not register one for this process.
|
||||
with patch(sig_to_mock):
|
||||
with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb:
|
||||
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
|
||||
self.sh_proc._run()
|
||||
ma.assert_called()
|
||||
mb.assert_called()
|
||||
|
||||
@staticmethod
|
||||
def pid_setting_target(sub_target, val, evt):
|
||||
val.value = os.getpid()
|
||||
p = multiprocessing.Process(target=sub_target, args=(evt,))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
@skipIf(sys.platform.startswith('win'), 'Required signals not supported on windows')
|
||||
def test_signal_processing_handle_signals_called(self):
|
||||
'Validate SignalHandlingMultiprocessingProcess handles signals'
|
||||
# Gloobal event to stop all processes we're creating
|
||||
evt = multiprocessing.Event()
|
||||
|
||||
# Create a process to test signal handler
|
||||
val = multiprocessing.Value('i', 0)
|
||||
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(
|
||||
target=self.pid_setting_target,
|
||||
args=(self.run_forever_sub_target, val, evt),
|
||||
)
|
||||
proc.start()
|
||||
|
||||
# Create a second process that should not respond to SIGINT or SIGTERM
|
||||
proc2 = multiprocessing.Process(
|
||||
target=self.run_forever_target,
|
||||
args=(self.run_forever_sub_target, evt),
|
||||
)
|
||||
proc2.start()
|
||||
|
||||
# Wait for the sub process to set it's pid
|
||||
while not val.value:
|
||||
time.sleep(.3)
|
||||
|
||||
assert not proc.signal_handled()
|
||||
|
||||
# Send a signal that should get handled by the subprocess
|
||||
os.kill(val.value, signal.SIGTERM)
|
||||
|
||||
# wait up to 10 seconds for signal handler:
|
||||
start = time.time()
|
||||
while time.time() - start < 10:
|
||||
if proc.signal_handled():
|
||||
break
|
||||
time.sleep(.3)
|
||||
|
||||
try:
|
||||
# Allow some time for the signal handler to do it's thing
|
||||
assert proc.signal_handled()
|
||||
# Reap the signaled process
|
||||
proc.join(1)
|
||||
assert proc2.is_alive()
|
||||
finally:
|
||||
evt.set()
|
||||
proc2.join(30)
|
||||
proc.join(30)
|
||||
|
Loading…
Reference in New Issue
Block a user