mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
384 lines
12 KiB
Go
384 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/service"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestRunScriptCommand(t *testing.T) {
|
|
_, ds := runServerWithMockedDS(t,
|
|
&service.TestServerOpts{
|
|
License: &fleet.LicenseInfo{
|
|
Tier: fleet.TierPremium,
|
|
},
|
|
},
|
|
&service.TestServerOpts{
|
|
HTTPServerConfig: &http.Server{WriteTimeout: 90 * time.Second}, // nolint:gosec
|
|
},
|
|
)
|
|
|
|
ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
|
|
return nil
|
|
}
|
|
ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
|
|
return nil, nil
|
|
}
|
|
ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
|
|
return nil, nil
|
|
}
|
|
ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
|
|
return nil, nil
|
|
}
|
|
ds.ListHostBatteriesFunc = func(ctx context.Context, hid uint) ([]*fleet.HostBattery, error) {
|
|
return nil, nil
|
|
}
|
|
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
|
return &fleet.AppConfig{ServerSettings: fleet.ServerSettings{ScriptsDisabled: false}}, nil
|
|
}
|
|
ds.GetScriptIDByNameFunc = func(ctx context.Context, name string, teamID *uint) (uint, error) {
|
|
return 1, nil
|
|
}
|
|
ds.IsExecutionPendingForHostFunc = func(ctx context.Context, hid uint, scriptID uint) ([]*uint, error) {
|
|
return []*uint{}, nil
|
|
}
|
|
|
|
generateValidPath := func() string {
|
|
return writeTmpScriptContents(t, "echo hello world", ".sh")
|
|
}
|
|
exceedsMaxCharsUnsaved := strings.Repeat("a", fleet.UnsavedScriptMaxRuneLen+1)
|
|
exceedsMaxCharsSaved := strings.Repeat("a", fleet.SavedScriptMaxRuneLen+1)
|
|
|
|
expectedOutputSuccess := `
|
|
Exit code: 0 (Script ran successfully.)
|
|
|
|
Output:
|
|
|
|
-------------------------------------------------------------------------------------
|
|
|
|
hello world
|
|
|
|
-------------------------------------------------------------------------------------
|
|
`
|
|
|
|
type testCase struct {
|
|
name string
|
|
scriptPath func() string
|
|
scriptName string
|
|
teamID *uint
|
|
savedScriptContents func() ([]byte, error)
|
|
scriptResult *fleet.HostScriptResult
|
|
expectOutput string
|
|
expectErrMsg string
|
|
expectNotFound bool
|
|
expectOffline bool
|
|
expectPending bool
|
|
}
|
|
|
|
cases := []testCase{
|
|
{
|
|
name: "host offline",
|
|
scriptPath: generateValidPath,
|
|
expectErrMsg: fleet.RunScriptHostOfflineErrMsg,
|
|
expectOffline: true,
|
|
},
|
|
{
|
|
name: "host not found",
|
|
scriptPath: generateValidPath,
|
|
expectErrMsg: fleet.RunScriptHostNotFoundErrMsg,
|
|
expectNotFound: true,
|
|
},
|
|
{
|
|
name: "invalid file type",
|
|
scriptPath: func() string { return writeTmpScriptContents(t, "echo hello world", ".txt") },
|
|
expectErrMsg: fleet.RunScriptInvalidTypeErrMsg,
|
|
},
|
|
{
|
|
name: "invalid hashbang",
|
|
scriptPath: func() string { return writeTmpScriptContents(t, "#! /foo/bar", ".sh") },
|
|
expectErrMsg: `Interpreter not supported. Bash scripts must run in "#!/bin/sh”.`,
|
|
},
|
|
{
|
|
name: "script too long (unsaved)",
|
|
scriptPath: func() string {
|
|
return writeTmpScriptContents(t, exceedsMaxCharsUnsaved, ".sh")
|
|
},
|
|
expectErrMsg: "Script is too large. Script referenced by '--script-path' is limited to 10,000 characters. To run larger script save it to Fleet and use '--script-name'.",
|
|
},
|
|
{
|
|
name: "script not too long (unsaved)",
|
|
scriptPath: func() string {
|
|
return writeTmpScriptContents(t, exceedsMaxCharsUnsaved[:fleet.UnsavedScriptMaxRuneLen], ".sh")
|
|
},
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(0),
|
|
Output: "hello world",
|
|
},
|
|
expectOutput: expectedOutputSuccess,
|
|
},
|
|
{
|
|
name: "script too long (saved)",
|
|
scriptName: "foo",
|
|
savedScriptContents: func() ([]byte, error) {
|
|
return []byte(exceedsMaxCharsSaved), nil
|
|
},
|
|
expectErrMsg: "Script is too large. It's limited to 500,000 characters (approximately 10,000 lines).",
|
|
},
|
|
{
|
|
name: "script not too long (saved)",
|
|
scriptName: "foo",
|
|
savedScriptContents: func() ([]byte, error) {
|
|
return []byte(exceedsMaxCharsUnsaved), nil
|
|
},
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(0),
|
|
Output: "hello world",
|
|
},
|
|
expectOutput: expectedOutputSuccess,
|
|
},
|
|
{
|
|
name: "script-path and script-name disallowed",
|
|
scriptPath: generateValidPath,
|
|
scriptName: "foo",
|
|
expectErrMsg: `Only one of '--script-path' or '--script-name' is allowed.`,
|
|
},
|
|
{
|
|
name: "missing one of script-path and script-nqme",
|
|
expectErrMsg: `One of '--script-path' or '--script-name' must be specified.`,
|
|
},
|
|
{
|
|
name: "script-path and team disallowed",
|
|
scriptPath: generateValidPath,
|
|
teamID: ptr.Uint(1),
|
|
expectErrMsg: `Only one of '--script-path' or '--team' is allowed.`,
|
|
},
|
|
{
|
|
name: "script empty",
|
|
scriptPath: func() string { return writeTmpScriptContents(t, "", ".sh") },
|
|
expectErrMsg: `Script contents must not be empty.`,
|
|
},
|
|
{
|
|
name: "invalid utf8",
|
|
scriptPath: func() string { return writeTmpScriptContents(t, "\xff\xfa", ".sh") },
|
|
expectErrMsg: `Wrong data format. Only plain text allowed.`,
|
|
},
|
|
{
|
|
name: "script already running",
|
|
scriptPath: generateValidPath,
|
|
expectErrMsg: fleet.RunScriptAlreadyRunningErrMsg,
|
|
expectPending: true,
|
|
},
|
|
{
|
|
name: "script successful",
|
|
scriptPath: generateValidPath,
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(0),
|
|
Output: "hello world",
|
|
},
|
|
expectOutput: expectedOutputSuccess,
|
|
},
|
|
{
|
|
name: "script failed",
|
|
scriptPath: generateValidPath,
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(1),
|
|
Output: "",
|
|
},
|
|
expectOutput: `
|
|
Exit code: 1 (Script failed.)
|
|
|
|
Output:
|
|
|
|
-------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
-------------------------------------------------------------------------------------
|
|
`,
|
|
},
|
|
{
|
|
name: "script killed",
|
|
scriptPath: generateValidPath,
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(-1),
|
|
Output: "Oh no!",
|
|
Message: fleet.RunScriptScriptTimeoutErrMsg,
|
|
},
|
|
expectOutput: `
|
|
Error: Timeout. Fleet stopped the script after 5 minutes to protect host performance.
|
|
|
|
Output before timeout:
|
|
|
|
-------------------------------------------------------------------------------------
|
|
|
|
Oh no!
|
|
|
|
-------------------------------------------------------------------------------------
|
|
`,
|
|
},
|
|
{
|
|
name: "scripts disabled",
|
|
scriptPath: generateValidPath,
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(-2),
|
|
Output: "",
|
|
Message: fleet.RunScriptDisabledErrMsg,
|
|
},
|
|
expectOutput: `
|
|
Error: Scripts are disabled for this host. To run scripts, deploy the fleetd agent with scripts enabled.
|
|
|
|
`,
|
|
},
|
|
{
|
|
name: "output truncated",
|
|
scriptPath: generateValidPath,
|
|
scriptResult: &fleet.HostScriptResult{
|
|
ExitCode: ptr.Int64(0),
|
|
Output: exceedsMaxCharsUnsaved,
|
|
},
|
|
expectOutput: fmt.Sprintf(`
|
|
Exit code: 0 (Script ran successfully.)
|
|
|
|
Output:
|
|
|
|
-------------------------------------------------------------------------------------
|
|
|
|
Fleet records the last 10,000 characters to prevent downtime.
|
|
|
|
%s
|
|
|
|
-------------------------------------------------------------------------------------
|
|
`, exceedsMaxCharsUnsaved),
|
|
},
|
|
// TODO: this would take 5 minutes to run, we don't want that kind of slowdown in our test suite
|
|
// but can be useful to have around for manual testing.
|
|
//{
|
|
// name: "host timeout",
|
|
// scriptPath: generateValidPath,
|
|
// expectErrMsg: fleet.RunScriptHostTimeoutErrMsg,
|
|
//},
|
|
{name: "disabled scripts globally", scriptPath: generateValidPath, expectErrMsg: fleet.RunScriptScriptsDisabledGloballyErrMsg},
|
|
}
|
|
|
|
setupDS := func(t *testing.T, c testCase) {
|
|
ds.HostByIdentifierFunc = func(ctx context.Context, ident string) (*fleet.Host, error) {
|
|
if ident != "host1" || c.expectNotFound {
|
|
return nil, ¬FoundError{}
|
|
}
|
|
return &fleet.Host{ID: 42, SeenTime: time.Now(), OrbitNodeKey: ptr.String("abc")}, nil
|
|
}
|
|
ds.HostFunc = func(ctx context.Context, hid uint) (*fleet.Host, error) {
|
|
if hid != 42 || c.expectNotFound {
|
|
return nil, ¬FoundError{}
|
|
}
|
|
h := fleet.Host{ID: hid, SeenTime: time.Now(), OrbitNodeKey: ptr.String("abc")}
|
|
if c.expectOffline {
|
|
h.SeenTime = time.Now().Add(-time.Hour)
|
|
}
|
|
return &h, nil
|
|
}
|
|
ds.ListPendingHostScriptExecutionsFunc = func(ctx context.Context, hid uint) ([]*fleet.HostScriptResult, error) {
|
|
require.Equal(t, uint(42), hid)
|
|
if c.expectPending {
|
|
return []*fleet.HostScriptResult{{HostID: uint(42)}}, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
ds.GetHostScriptExecutionResultFunc = func(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
|
|
if c.scriptResult != nil {
|
|
return c.scriptResult, nil
|
|
}
|
|
return &fleet.HostScriptResult{}, nil
|
|
}
|
|
ds.GetHostLockWipeStatusFunc = func(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
|
|
return &fleet.HostLockWipeStatus{}, nil
|
|
}
|
|
ds.NewHostScriptExecutionRequestFunc = func(ctx context.Context, req *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
|
|
require.Equal(t, uint(42), req.HostID)
|
|
return &fleet.HostScriptResult{
|
|
Hostname: "host1",
|
|
HostID: req.HostID,
|
|
ScriptContents: req.ScriptContents,
|
|
}, nil
|
|
}
|
|
if c.name == "disabled scripts globally" {
|
|
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
|
return &fleet.AppConfig{ServerSettings: fleet.ServerSettings{ScriptsDisabled: true}}, nil
|
|
}
|
|
} else {
|
|
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
|
return &fleet.AppConfig{ServerSettings: fleet.ServerSettings{ScriptsDisabled: false}}, nil
|
|
}
|
|
}
|
|
if c.savedScriptContents != nil {
|
|
ds.GetScriptContentsFunc = func(ctx context.Context, id uint) ([]byte, error) {
|
|
return c.savedScriptContents()
|
|
}
|
|
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
|
|
return &fleet.Script{ID: id, Name: "foo"}, nil
|
|
}
|
|
} else {
|
|
ds.GetScriptContentsFunc = func(ctx context.Context, id uint) ([]byte, error) {
|
|
return []byte("echo hello world"), nil
|
|
}
|
|
ds.ScriptFunc = func(ctx context.Context, id uint) (*fleet.Script, error) {
|
|
return &fleet.Script{ID: id, Name: "foo"}, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
setupDS(t, c)
|
|
args := []string{"run-script", "--host", "host1"}
|
|
|
|
if c.scriptPath != nil {
|
|
scriptPath := c.scriptPath()
|
|
defer os.Remove(scriptPath)
|
|
args = append(args, "--script-path", scriptPath)
|
|
}
|
|
|
|
if c.scriptName != "" {
|
|
args = append(args, "--script-name", c.scriptName)
|
|
}
|
|
|
|
if c.teamID != nil {
|
|
args = append(args, "--team", fmt.Sprintf("%d", *c.teamID))
|
|
}
|
|
|
|
b, err := runAppNoChecks(args)
|
|
if c.expectErrMsg != "" {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), c.expectErrMsg)
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
if c.scriptResult != nil {
|
|
out := b.String()
|
|
require.NoError(t, err)
|
|
require.NotEmpty(t, out)
|
|
require.Equal(t, c.expectOutput, out)
|
|
} else {
|
|
require.Empty(t, b.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func writeTmpScriptContents(t *testing.T, scriptContents string, extension string) string {
|
|
tmpFile, err := os.CreateTemp(t.TempDir(), "*"+extension)
|
|
require.NoError(t, err)
|
|
_, err = tmpFile.WriteString(scriptContents)
|
|
require.NoError(t, err)
|
|
return tmpFile.Name()
|
|
}
|