Refactor WaitOrKill into simplified StopOrKill

Use a default stop signal set by the GOOS.
This commit is contained in:
Zachary Wasserman 2020-12-16 16:59:58 -08:00
parent 6901819fdf
commit d562f82718
2 changed files with 32 additions and 34 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt"
"os"
"os/exec"
"runtime"
"time"
)
@ -47,15 +48,11 @@ func newWithMock(cmd ExecCmd) *Process {
//
// Adapted from Go core:
// https://github.com/golang/go/blob/8981092d71aee273d27b0e11cf932a34d4d365c1/src/cmd/go/script_test.go#L1131-L1190
func (p *Process) WaitOrKill(ctx context.Context, interrupt os.Signal, killDelay time.Duration) error {
func (p *Process) StopOrKill(ctx context.Context, killDelay time.Duration) error {
if p.OsProcess() == nil {
return fmt.Errorf("WaitOrKill requires a non-nil OsProcess - missing Start call?")
}
if interrupt == nil {
return fmt.Errorf("WaitOrKill requires a non-nil interrupt signal")
}
errc := make(chan error)
go func() {
select {
@ -64,7 +61,7 @@ func (p *Process) WaitOrKill(ctx context.Context, interrupt os.Signal, killDelay
case <-ctx.Done():
}
err := p.OsProcess().Signal(interrupt)
err := p.OsProcess().Signal(stopSignal())
if err == nil {
err = ctx.Err() // Report ctx.Err() as the reason we interrupted.
} else if err.Error() == "os: process already finished" {
@ -101,3 +98,18 @@ func (p *Process) WaitOrKill(ctx context.Context, interrupt os.Signal, killDelay
}
return waitErr
}
// stopSignal returns the appropriate signal to use to request that a process
// stop execution.
//
// Copied from Go core:
// https://github.com/golang/go/blob/8981092d71aee273d27b0e11cf932a34d4d365c1/src/cmd/go/script_test.go#L1119-L1129
func stopSignal() os.Signal {
if runtime.GOOS == "windows" {
// Per https://golang.org/pkg/os/#Signal, “Interrupt is not implemented on
// Windows; using it with os.Process.Signal will return an error.”
// Fall back to Kill instead.
return os.Kill
}
return os.Interrupt
}

View File

@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"os"
"testing"
"time"
@ -19,25 +18,12 @@ func TestWaitOrKillNilProcess(t *testing.T) {
mockCmd.On("OsProcess").Return(nil)
p := newWithMock(mockCmd)
err := p.WaitOrKill(context.Background(), os.Interrupt, 1*time.Second)
err := p.StopOrKill(context.Background(), 1*time.Second)
require.Error(t, err)
assert.Contains(t, err.Error(), "non-nil OsProcess")
}
func TestWaitOrKillNilSignal(t *testing.T) {
// Nil signal provided
mockCmd := &mockExecCmd{}
mockProcess := &mockOsProcess{}
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
p := newWithMock(mockCmd)
err := p.WaitOrKill(context.Background(), nil, 10*time.Millisecond)
require.Error(t, err)
assert.Contains(t, err.Error(), "non-nil interrupt")
}
func TestWaitOrKillProcessCompleted(t *testing.T) {
func TestStopOrKillProcessCompleted(t *testing.T) {
// Process already completed
mockCmd := &mockExecCmd{}
mockProcess := &mockOsProcess{}
@ -46,11 +32,11 @@ func TestWaitOrKillProcessCompleted(t *testing.T) {
mockCmd.On("Wait").Return(nil)
p := newWithMock(mockCmd)
err := p.WaitOrKill(context.Background(), os.Interrupt, 10*time.Millisecond)
err := p.StopOrKill(context.Background(), 10*time.Millisecond)
require.NoError(t, err)
}
func TestWaitOrKillProcessCompletedError(t *testing.T) {
func TestStopOrKillProcessCompletedError(t *testing.T) {
// Process already completed with error
mockCmd := &mockExecCmd{}
mockProcess := &mockOsProcess{}
@ -59,58 +45,58 @@ func TestWaitOrKillProcessCompletedError(t *testing.T) {
mockCmd.On("Wait").After(10 * time.Millisecond).Return(fmt.Errorf("super bad"))
p := newWithMock(mockCmd)
err := p.WaitOrKill(context.Background(), os.Interrupt, 10*time.Millisecond)
err := p.StopOrKill(context.Background(), 10*time.Millisecond)
require.Error(t, err)
assert.Contains(t, err.Error(), "super bad")
}
func TestWaitOrKillWait(t *testing.T) {
func TestStopOrKillWait(t *testing.T) {
// Process completes after the wait call and after the signal is sent
mockCmd := &mockExecCmd{}
mockProcess := &mockOsProcess{}
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
mockCmd.On("Wait").After(5 * time.Millisecond).Return(nil)
mockProcess.On("Signal", os.Interrupt).Return(nil)
mockProcess.On("Signal", stopSignal()).Return(nil)
p := newWithMock(mockCmd)
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := p.WaitOrKill(ctx, os.Interrupt, 10*time.Millisecond)
err := p.StopOrKill(ctx, 10*time.Millisecond)
require.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
}
func TestWaitOrKillWaitSignalCompleted(t *testing.T) {
func TestStopOrKillWaitSignalCompleted(t *testing.T) {
// Process completes after the wait call and before the signal is sent
mockCmd := &mockExecCmd{}
mockProcess := &mockOsProcess{}
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
mockCmd.On("Wait").After(10 * time.Millisecond).Return(nil)
mockProcess.On("Signal", os.Interrupt).Return(fmt.Errorf("os: process already finished"))
mockProcess.On("Signal", stopSignal()).Return(fmt.Errorf("os: process already finished"))
p := newWithMock(mockCmd)
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := p.WaitOrKill(ctx, os.Interrupt, 5*time.Millisecond)
err := p.StopOrKill(ctx, 5*time.Millisecond)
require.NoError(t, err)
}
func TestWaitOrKillWaitKilled(t *testing.T) {
func TestStopOrKillWaitKilled(t *testing.T) {
// Process is killed after the wait call and signal
mockCmd := &mockExecCmd{}
mockProcess := &mockOsProcess{}
defer mock.AssertExpectationsForObjects(t, mockCmd, mockProcess)
mockCmd.On("OsProcess").Return(mockProcess)
mockCmd.On("Wait").After(10 * time.Millisecond).Return(fmt.Errorf("killed"))
mockProcess.On("Signal", os.Interrupt).Return(nil)
mockProcess.On("Signal", stopSignal()).Return(nil)
mockProcess.On("Kill").Return(nil)
p := newWithMock(mockCmd)
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := p.WaitOrKill(ctx, os.Interrupt, 5*time.Millisecond)
err := p.StopOrKill(ctx, 5*time.Millisecond)
require.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
}