Fix signal handling in subprocesses

This commit is contained in:
Daniel A. Wozniak 2019-08-13 14:28:39 -07:00
parent a13cb3eae6
commit 1d83b10025
No known key found for this signature in database
GPG Key ID: 166B9D2C06C82D61
2 changed files with 179 additions and 80 deletions

View File

@ -669,22 +669,9 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
return instance
def __init__(self, *args, **kwargs):
if (salt.utils.platform.is_windows() and
not hasattr(self, '_is_child') and
self.__setstate__.__code__ is
MultiprocessingProcess.__setstate__.__code__):
# On Windows, if a derived class hasn't defined __setstate__, that
# means the 'MultiprocessingProcess' version will be used. For this
# version, save a copy of the args and kwargs to use with its
# __setstate__ and __getstate__.
# We do this so that __init__ will be invoked on Windows in the
# child process so that a register_after_fork() equivalent will
# work on Windows. Note that this will only work if the derived
# class uses the exact same args and kwargs as this class. Hence
# this will also work for 'SignalHandlingMultiprocessingProcess'.
# However, many derived classes take params that they don't pass
# down (eg opts). Those classes need to override __setstate__ and
# __getstate__ themselves.
if salt.utils.platform.is_windows():
# On Windows, subclasses should call super if they define
# __setstate__ and/or __getstate__
self._args_for_getstate = copy.copy(args)
self._kwargs_for_getstate = copy.copy(kwargs)
@ -706,42 +693,21 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
# 'log_queue' and 'log_queue_level' from kwargs.
super(MultiprocessingProcess, self).__init__(*args, **kwargs)
if salt.utils.platform.is_windows():
# On Windows, the multiprocessing.Process object is reinitialized
# in the child process via the constructor. Due to this, methods
# such as ident() and is_alive() won't work properly. So we use
# our own creation '_is_child' for this purpose.
if hasattr(self, '_is_child'):
# On Windows, no need to call register_after_fork().
# register_after_fork() would only work on Windows if called
# from the child process anyway. Since we know this is the
# child process, call __setup_process_logging() directly.
self.__setup_process_logging()
multiprocessing.util.Finalize(
self,
salt.log.setup.shutdown_multiprocessing_logging,
exitpriority=16
)
else:
multiprocessing.util.register_after_fork(
self,
MultiprocessingProcess.__setup_process_logging
)
multiprocessing.util.Finalize(
self,
salt.log.setup.shutdown_multiprocessing_logging,
exitpriority=16
)
self._after_fork_methods = [
(MultiprocessingProcess._setup_process_logging, [self], {}),
]
self._finalize_methods = [
(salt.log.setup.shutdown_multiprocessing_logging, [], {})
]
# __setstate__ and __getstate__ are only used on Windows.
# We do this so that __init__ will be invoked on Windows in the child
# process so that a register_after_fork() equivalent will work on Windows.
def __setstate__(self, state):
self._is_child = True
args = state['args']
kwargs = state['kwargs']
# This will invoke __init__ of the most derived class.
self.__init__(*args, **kwargs)
self._after_fork_methods = self._after_fork_methods
self._finalize_methods = self._finalize_methods
def __getstate__(self):
args = self._args_for_getstate
@ -755,12 +721,17 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
del self._args_for_getstate
del self._kwargs_for_getstate
return {'args': args,
'kwargs': kwargs}
'kwargs': kwargs,
'_after_fork_methods': self._after_fork_methods,
'_finalize_methods': self._finalize_methods,
}
def __setup_process_logging(self):
def _setup_process_logging(self):
salt.log.setup.setup_multiprocessing_logging(self.log_queue)
def _run(self):
for method, args, kwargs in self._after_fork_methods:
method(*args, **kwargs)
try:
return self._original_run()
except SystemExit:
@ -774,31 +745,28 @@ class MultiprocessingProcess(multiprocessing.Process, NewStyleClassMixIn):
# sys.stderr and set the proper exitcode and we have already logged
# it above.
raise
finally:
for method, args, kwargs in self._finalize_methods:
method(*args, **kwargs)
class SignalHandlingMultiprocessingProcess(MultiprocessingProcess):
def __init__(self, *args, **kwargs):
super(SignalHandlingMultiprocessingProcess, self).__init__(*args, **kwargs)
if salt.utils.platform.is_windows():
if hasattr(self, '_is_child'):
# On Windows, no need to call register_after_fork().
# register_after_fork() would only work on Windows if called
# from the child process anyway. Since we know this is the
# child process, call __setup_signals() directly.
self.__setup_signals()
else:
multiprocessing.util.register_after_fork(
self,
SignalHandlingMultiprocessingProcess.__setup_signals
)
self._signal_handled = multiprocessing.Event()
self._after_fork_methods.append(
(SignalHandlingMultiprocessingProcess._setup_signals, [self], {})
)
def __setup_signals(self):
def signal_handled(self):
return self._signal_handled.is_set()
def _setup_signals(self):
signal.signal(signal.SIGINT, self._handle_signals)
signal.signal(signal.SIGTERM, self._handle_signals)
def _handle_signals(self, signum, sigframe):
signal.signal(signal.SIGTERM, signal.SIG_IGN)
signal.signal(signal.SIGINT, signal.SIG_IGN)
self._signal_handled.set()
msg = '{0} received a '.format(self.__class__.__name__)
if signum == signal.SIGINT:
msg += 'SIGINT'
@ -808,7 +776,7 @@ class SignalHandlingMultiprocessingProcess(MultiprocessingProcess):
log.debug(msg)
if HAS_PSUTIL:
try:
process = psutil.Process(self.pid)
process = psutil.Process(os.getpid())
if hasattr(process, 'children'):
for child in process.children(recursive=True):
try:

View File

@ -236,35 +236,166 @@ class TestProcess(TestCase):
# pylint: enable=assignment-from-none
@skipIf(sys.platform.startswith('win'), 'pickling nested function errors on Windows')
class TestSignalHandlingMultiprocessingProcess(TestCase):
@classmethod
def Process(cls, pid):
raise psutil.NoSuchProcess(pid)
@classmethod
def target(cls):
os.kill(os.getpid(), signal.SIGTERM)
@classmethod
def children(cls, *args, **kwargs):
raise psutil.NoSuchProcess(1)
@skipIf(NO_MOCK, NO_MOCK_REASON)
def test_process_does_not_exist(self):
def Process(pid):
raise psutil.NoSuchProcess(pid)
def target():
os.kill(os.getpid(), signal.SIGTERM)
try:
with patch('psutil.Process', Process):
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=target)
with patch('psutil.Process', self.Process):
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.target)
proc.start()
except psutil.NoSuchProcess:
assert False, "psutil.NoSuchProcess raised"
@skipIf(NO_MOCK, NO_MOCK_REASON)
def test_process_children_do_not_exist(self):
def children(*args, **kwargs):
raise psutil.NoSuchProcess(1)
def target():
os.kill(os.getpid(), signal.SIGTERM)
try:
with patch('psutil.Process.children', children):
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=target)
with patch('psutil.Process.children', self.children):
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.target)
proc.start()
except psutil.NoSuchProcess:
assert False, "psutil.NoSuchProcess raised"
@staticmethod
def run_forever_sub_target(evt):
'Used by run_forever_target to create a sub-process'
while not evt.is_set():
time.sleep(1)
@staticmethod
def run_forever_target(sub_target, evt):
'A target that will run forever or until an event is set'
p = multiprocessing.Process(target=sub_target, args=(evt,))
p.start()
p.join()
@staticmethod
def kill_target_sub_proc():
pid = os.fork()
if pid == 0:
return
pid = os.fork()
if pid == 0:
return
time.sleep(.1)
try:
os.kill(os.getpid(), signal.SIGINT)
except KeyboardInterrupt:
pass
@skipIf(sys.platform.startswith('win'), 'No os.fork on Windows')
def test_signal_processing_regression_test(self):
evt = multiprocessing.Event()
sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(
target=self.run_forever_target,
args=(self.run_forever_sub_target, evt)
)
sh_proc.start()
proc = multiprocessing.Process(target=self.kill_target_sub_proc)
proc.start()
proc.join()
# When the bug exists, the kill_target_sub_proc signal will kill both
# processes. sh_proc will be alive if the bug is fixed
try:
assert sh_proc.is_alive()
finally:
evt.set()
sh_proc.join()
@staticmethod
def no_op_target():
pass
@skipIf(NO_MOCK, NO_MOCK_REASON)
def test_signal_processing_test_after_fork_called(self):
'Validate MultiprocessingProcess and sub classes call after fork methods'
evt = multiprocessing.Event()
sig_to_mock = 'salt.utils.process.SignalHandlingMultiprocessingProcess._setup_signals'
log_to_mock = 'salt.utils.process.MultiprocessingProcess._setup_process_logging'
with patch(sig_to_mock) as ma, patch(log_to_mock) as mb:
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
self.sh_proc._run()
ma.assert_called()
mb.assert_called()
@skipIf(NO_MOCK, NO_MOCK_REASON)
def test_signal_processing_test_final_methods_called(self):
'Validate MultiprocessingProcess and sub classes call finalize methods'
evt = multiprocessing.Event()
teardown_to_mock = 'salt.log.setup.shutdown_multiprocessing_logging'
log_to_mock = 'salt.utils.process.MultiprocessingProcess._setup_process_logging'
sig_to_mock = 'salt.utils.process.SignalHandlingMultiprocessingProcess._setup_signals'
# Mock _setup_signals so we do not register one for this process.
with patch(sig_to_mock):
with patch(teardown_to_mock) as ma, patch(log_to_mock) as mb:
self.sh_proc = salt.utils.process.SignalHandlingMultiprocessingProcess(target=self.no_op_target)
self.sh_proc._run()
ma.assert_called()
mb.assert_called()
@staticmethod
def pid_setting_target(sub_target, val, evt):
val.value = os.getpid()
p = multiprocessing.Process(target=sub_target, args=(evt,))
p.start()
p.join()
@skipIf(sys.platform.startswith('win'), 'Required signals not supported on windows')
def test_signal_processing_handle_signals_called(self):
'Validate SignalHandlingMultiprocessingProcess handles signals'
# Gloobal event to stop all processes we're creating
evt = multiprocessing.Event()
# Create a process to test signal handler
val = multiprocessing.Value('i', 0)
proc = salt.utils.process.SignalHandlingMultiprocessingProcess(
target=self.pid_setting_target,
args=(self.run_forever_sub_target, val, evt),
)
proc.start()
# Create a second process that should not respond to SIGINT or SIGTERM
proc2 = multiprocessing.Process(
target=self.run_forever_target,
args=(self.run_forever_sub_target, evt),
)
proc2.start()
# Wait for the sub process to set it's pid
while not val.value:
time.sleep(.3)
assert not proc.signal_handled()
# Send a signal that should get handled by the subprocess
os.kill(val.value, signal.SIGTERM)
# wait up to 10 seconds for signal handler:
start = time.time()
while time.time() - start < 10:
if proc.signal_handled():
break
time.sleep(.3)
try:
# Allow some time for the signal handler to do it's thing
assert proc.signal_handled()
# Reap the signaled process
proc.join(1)
assert proc2.is_alive()
finally:
evt.set()
proc2.join(30)
proc.join(30)