Add /scripts/run and scripts/run/sync API endpoints to run scripts (part 1) (#13417)

This commit is contained in:
Martin Angers 2023-08-21 14:47:19 -04:00 committed by GitHub
parent 3b61adf7a4
commit de32faefdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1252 additions and 31 deletions

View File

@ -5,7 +5,7 @@ If some of the following don't apply, delete the relevant line.
- [ ] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. - [ ] Changes file added for user-visible changes in `changes/` or `orbit/changes/`.
See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information.
- [ ] Documented any API changes (docs/Using-Fleet/REST-API.md or docs/Contributing/API-for-contributors.md) - [ ] Documented any API changes (docs/Using-Fleet/REST-API.md or docs/Contributing/API-for-contributors.md)
- [ ] Documented any permissions changes - [ ] Documented any permissions changes (docs/Using Fleet/manage-access.md)
- [ ] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [ ] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements)
- [ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for new osquery data ingestion features. - [ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for new osquery data ingestion features.
- [ ] Added/updated tests - [ ] Added/updated tests

View File

@ -0,0 +1 @@
* Added `/scripts/run` and `scripts/run/sync` API endpoints to send a script to be executed on a host (and optionally wait for its results).

View File

@ -10,6 +10,7 @@
- [Policies](#policies) - [Policies](#policies)
- [Queries](#queries) - [Queries](#queries)
- [Schedule (deprecated)](#schedule) - [Schedule (deprecated)](#schedule)
- [Scripts](#scripts)
- [Sessions](#sessions) - [Sessions](#sessions)
- [Software](#software) - [Software](#software)
- [Targets](#targets) - [Targets](#targets)
@ -3126,7 +3127,7 @@ Retrieves the disk encryption key for a host.
} }
``` ```
### Get configuration profiles assigned to a host ### Get configuration profiles assigned to a host
Requires Fleet's MDM properly [enabled and configured](https://fleetdm.com/docs/using-fleet/mdm-setup). Requires Fleet's MDM properly [enabled and configured](https://fleetdm.com/docs/using-fleet/mdm-setup).
@ -5626,7 +5627,7 @@ load balancer timeout.
## Schedule ## Schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
- [Get schedule (deprecated)](#get-schedule) - [Get schedule (deprecated)](#get-schedule)
@ -5641,7 +5642,7 @@ These API routes let you control your scheduled queries.
### Get schedule ### Get schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`GET /api/v1/fleet/global/schedule` `GET /api/v1/fleet/global/schedule`
@ -5715,7 +5716,7 @@ None.
### Add query to schedule ### Add query to schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`POST /api/v1/fleet/global/schedule` `POST /api/v1/fleet/global/schedule`
@ -5776,7 +5777,7 @@ None.
### Edit query in schedule ### Edit query in schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`PATCH /api/v1/fleet/global/schedule/{id}` `PATCH /api/v1/fleet/global/schedule/{id}`
@ -5832,7 +5833,7 @@ None.
### Remove query from schedule ### Remove query from schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`DELETE /api/v1/fleet/global/schedule/{id}` `DELETE /api/v1/fleet/global/schedule/{id}`
@ -5854,7 +5855,7 @@ None.
### Team schedule ### Team schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
- [Get team schedule (deprecated)](#get-team-schedule) - [Get team schedule (deprecated)](#get-team-schedule)
@ -5866,7 +5867,7 @@ This allows you to easily configure scheduled queries that will impact a whole t
#### Get team schedule #### Get team schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`GET /api/v1/fleet/teams/{id}/schedule` `GET /api/v1/fleet/teams/{id}/schedule`
@ -5946,7 +5947,7 @@ This allows you to easily configure scheduled queries that will impact a whole t
#### Add query to team schedule #### Add query to team schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`POST /api/v1/fleet/teams/{id}/schedule` `POST /api/v1/fleet/teams/{id}/schedule`
@ -6004,7 +6005,7 @@ This allows you to easily configure scheduled queries that will impact a whole t
#### Edit query in team schedule #### Edit query in team schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`PATCH /api/v1/fleet/teams/{team_id}/schedule/{scheduled_query_id}` `PATCH /api/v1/fleet/teams/{team_id}/schedule/{scheduled_query_id}`
@ -6061,7 +6062,7 @@ This allows you to easily configure scheduled queries that will impact a whole t
#### Remove query from team schedule #### Remove query from team schedule
> The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility. > The schedule API endpoints are deprecated as of Fleet 4.35. They are maintained for backwards compatibility.
> Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling. > Please use the [queries](#queries) endpoints, which as of 4.35 have attributes such as `interval` and `platform` that enable scheduling.
`DELETE /api/v1/fleet/teams/{team_id}/schedule/{scheduled_query_id}` `DELETE /api/v1/fleet/teams/{team_id}/schedule/{scheduled_query_id}`
@ -6083,6 +6084,78 @@ This allows you to easily configure scheduled queries that will impact a whole t
--- ---
## Scripts
- [Run script asynchronously](#run-script-asynchronously)
- [Run script synchronously](#run-script-synchronously)
### Run script asynchronously
_Available in Fleet Premium_
Creates a script execution request and returns the execution identifier to retrieve results at a later time.
`POST /api/v1/fleet/scripts/run`
#### Parameters
| Name | Type | In | Description |
| ---- | ------- | ---- | -------------------------------------------- |
| host_id | integer | body | **Required**. The host id to run the script on. |
| script_contents | string | body | **Required**. The contents of the script to run. |
#### Example
`POST /api/v1/fleet/scripts/run`
##### Default response
`Status: 202`
```json
{
"host_id": 1227,
"execution_id": "e797d6c6-3aae-11ee-be56-0242ac120002"
}
```
### Run script synchronously
_Available in Fleet Premium_
Creates a script execution request and waits for a result to return (up to a 1 minute timeout).
`POST /api/v1/fleet/scripts/run/sync`
#### Parameters
| Name | Type | In | Description |
| ---- | ------- | ---- | -------------------------------------------- |
| host_id | integer | body | **Required**. The host id to run the script on. |
| script_contents | string | body | **Required**. The contents of the script to run. |
#### Example
`POST /api/v1/fleet/scripts/run/sync`
##### Default response
`Status: 200`
```json
{
"host_id": 1227,
"execution_id": "e797d6c6-3aae-11ee-be56-0242ac120002",
"script_contents": "echo 'hello'",
"output": "hello",
"runtime": 1,
"exit_code": 0
}
```
---
## Sessions ## Sessions
- [Get session info](#get-session-info) - [Get session info](#get-session-info)

View File

@ -87,6 +87,7 @@ GitOps is an API-only and write-only role that can be used on CI/CD pipelines.
| View metadata of MDM macOS bootstrap packages\* | | | ✅ | ✅ | | | View metadata of MDM macOS bootstrap packages\* | | | ✅ | ✅ | |
| Edit/upload MDM macOS bootstrap packages\* | | | ✅ | ✅ | ✅ | | Edit/upload MDM macOS bootstrap packages\* | | | ✅ | ✅ | ✅ |
| Enable/disable MDM macOS setup end user authentication\* | | | ✅ | ✅ | ✅ | | Enable/disable MDM macOS setup end user authentication\* | | | ✅ | ✅ | ✅ |
| Run scripts on hosts\* | | | ✅ | ✅ | |
\* Applies only to Fleet Premium \* Applies only to Fleet Premium
@ -149,6 +150,7 @@ Users that are members of multiple teams can be assigned different roles for eac
| View metadata of MDM macOS bootstrap packages | | | ✅ | ✅ | | | View metadata of MDM macOS bootstrap packages | | | ✅ | ✅ | |
| Edit/upload MDM macOS bootstrap packages | | | ✅ | ✅ | ✅ | | Edit/upload MDM macOS bootstrap packages | | | ✅ | ✅ | ✅ |
| Enable/disable MDM macOS setup end user authentication | | | ✅ | ✅ | ✅ | | Enable/disable MDM macOS setup end user authentication | | | ✅ | ✅ | ✅ |
| Run scripts on hosts | | | ✅ | ✅ | |
\* Applies only to [Fleet REST API](https://fleetdm.com/docs/using-fleet/rest-api) \* Applies only to [Fleet REST API](https://fleetdm.com/docs/using-fleet/rest-api)
@ -156,4 +158,4 @@ Users that are members of multiple teams can be assigned different roles for eac
<meta name="pageOrderInSection" value="900"> <meta name="pageOrderInSection" value="900">
<meta name="description" value="Learn about the different roles and permissions in Fleet."> <meta name="description" value="Learn about the different roles and permissions in Fleet.">
<meta name="navSection" value="The basics"> <meta name="navSection" value="The basics">

View File

@ -2,7 +2,11 @@ package service
import ( import (
"context" "context"
"fmt"
"time"
"unicode/utf8"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/fleet"
) )
@ -19,3 +23,109 @@ func (svc *Service) HostByIdentifier(ctx context.Context, identifier string, opt
opts.IncludePolicies = true opts.IncludePolicies = true
return svc.Service.HostByIdentifier(ctx, identifier, opts) return svc.Service.HostByIdentifier(ctx, identifier, opts)
} }
func (svc *Service) RunHostScript(ctx context.Context, request *fleet.HostScriptRequestPayload, waitForResult time.Duration) (*fleet.HostScriptResult, error) {
const (
maxScriptRuneLen = 10000
maxPendingScriptAge = time.Minute // any script older than this is not considered pending anymore on that host
)
// must load the host (lite is enough, just for the team) to authorize
// with the proper team id. We cannot first authorize if the user can list
// hosts, because the user could have a write-only role (e.g. gitops).
host, err := svc.ds.HostLite(ctx, request.HostID)
if err != nil {
// if error is because the host does not exist, check first if the user
// had access to run a script (to prevent leaking valid host ids).
if fleet.IsNotFound(err) {
if err := svc.authz.Authorize(ctx, &fleet.HostScriptResult{}, fleet.ActionWrite); err != nil {
return nil, err
}
}
svc.authz.SkipAuthorization(ctx)
return nil, ctxerr.Wrap(ctx, err, "get host lite")
}
if err := svc.authz.Authorize(ctx, &fleet.HostScriptResult{TeamID: host.TeamID}, fleet.ActionWrite); err != nil {
return nil, err
}
if request.ScriptContents == "" {
return nil, fleet.NewInvalidArgumentError("script_contents", "a script to execute is required")
}
// look for the script length in bytes first, as rune counting a huge string
// can be expensive.
if len(request.ScriptContents) > utf8.UTFMax*maxScriptRuneLen {
return nil, fleet.NewInvalidArgumentError("script_contents", fmt.Sprintf("script is too long, must be at most %d characters", maxScriptRuneLen))
}
// now that we know that the script is at most 4*maxScriptRuneLen bytes long,
// we can safely count the runes for a precise check.
if utf8.RuneCountInString(request.ScriptContents) > maxScriptRuneLen {
return nil, fleet.NewInvalidArgumentError("script_contents", fmt.Sprintf("script is too long, must be at most %d characters", maxScriptRuneLen))
}
// TODO(mna): any other validation we want to apply to the script? What is the "must be bash/powershell" check?
pending, err := svc.ds.ListPendingHostScriptExecutions(ctx, request.HostID, maxPendingScriptAge)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "list host pending script executions")
}
if len(pending) > 0 {
// TODO(mna): there are a number of issues with that validation: it only
// really says that there was a script execution _request_ that was made < 1m
// ago, and that blocks executing any more scripts on that host, but the
// host may not even have received the previous script for execution yet,
// so if we accept more scripts after 1m, we may end up having multiple
// scripts to execute on the host at the same time (or more likely in
// sequence, but still). This may be good enough for now, I think the whole
// idea of locking if a script is pending is meant to be temporary anyway.
return nil, fleet.NewInvalidArgumentError("script_contents", "a script is currently executing on the host")
}
// create the script execution request
script, err := svc.ds.NewHostScriptExecutionRequest(ctx, request)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "create script execution request")
}
// TODO(mna): figure out how to send this to the host, either something to do
// here or via the DB checking if there are pending scripts for the host when
// sending queries or notifications.
if waitForResult <= 0 {
// async execution, return
return script, nil
}
ctx, cancel := context.WithTimeout(ctx, waitForResult)
defer cancel()
// if waiting for a result times out, we still want to return the script's
// execution request information along with the error, so that the caller can
// use the execution id for later checks.
timeoutResult := script
checkInterval := time.Second
after := time.NewTimer(checkInterval)
for {
select {
case <-ctx.Done():
return timeoutResult, ctx.Err()
case <-after.C:
result, err := svc.ds.GetHostScriptExecutionResult(ctx, script.ExecutionID)
if err != nil {
// is that due to the context being canceled during the DB access?
if ctxErr := ctx.Err(); ctxErr != nil {
return timeoutResult, ctxErr
}
return nil, ctxerr.Wrap(ctx, err, "get script execution result")
}
if result.ExitCode.Valid {
// a result was received from the host, return
return result, nil
}
// at a second to every attempt, until it reaches 5s (then check every 5s)
if checkInterval < 5*time.Second {
checkInterval += time.Second
}
after.Reset(checkInterval)
}
}
}

View File

@ -350,7 +350,7 @@ allow {
object.observer_can_run == false object.observer_can_run == false
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
is_null(object.team_id) is_null(object.team_id)
not is_null(object.host_targets.teams) not is_null(object.host_targets.teams)
@ -365,7 +365,7 @@ allow {
object.observer_can_run == false object.observer_can_run == false
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
team_role(subject, object.team_id) == [admin, maintainer, observer_plus][_] team_role(subject, object.team_id) == [admin, maintainer, observer_plus][_]
not is_null(object.host_targets.teams) not is_null(object.host_targets.teams)
@ -395,7 +395,7 @@ allow {
object.observer_can_run == false object.observer_can_run == false
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
team_role(subject, object.team_id) == [admin, maintainer, observer_plus][_] team_role(subject, object.team_id) == [admin, maintainer, observer_plus][_]
# there are no team targets # there are no team targets
@ -425,7 +425,7 @@ allow {
object.observer_can_run == true object.observer_can_run == true
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
is_null(object.team_id) is_null(object.team_id)
not is_null(object.host_targets.teams) not is_null(object.host_targets.teams)
@ -440,7 +440,7 @@ allow {
object.observer_can_run == true object.observer_can_run == true
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
team_role(subject, object.team_id) == [admin, maintainer, observer_plus, observer][_] team_role(subject, object.team_id) == [admin, maintainer, observer_plus, observer][_]
not is_null(object.host_targets.teams) not is_null(object.host_targets.teams)
@ -454,7 +454,7 @@ allow {
object.observer_can_run == true object.observer_can_run == true
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
is_null(object.team_id) is_null(object.team_id)
# If role is admin, maintainer, observer_plus or observer on any team. # If role is admin, maintainer, observer_plus or observer on any team.
@ -470,7 +470,7 @@ allow {
object.observer_can_run == true object.observer_can_run == true
is_null(subject.global_role) is_null(subject.global_role)
action == run action == run
team_role(subject, object.team_id) == [admin, maintainer, observer_plus, observer][_] team_role(subject, object.team_id) == [admin, maintainer, observer_plus, observer][_]
# there are no team targets # there are no team targets
@ -814,3 +814,39 @@ allow {
not is_null(subject) not is_null(subject)
action == read action == read
} }
##
# Host Script Result (script execution and output)
##
# Global admins and maintainers can write (execute) scripts (not gitops as this
# is not something that relates to fleetctl apply).
allow {
object.type == "host_script_result"
subject.global_role == [admin, maintainer][_]
action == write
}
# Global admins, maintainers, observer_plus and observers can read scripts.
allow {
object.type == "host_script_result"
subject.global_role == [admin, maintainer, observer, observer_plus][_]
action == read
}
# Team admin and maintainers can write (execute) scripts for their teams (not
# gitops as this is not something that relates to fleetctl apply).
allow {
object.type == "host_script_result"
not is_null(object.team_id)
team_role(subject, object.team_id) == [admin, maintainer][_]
action == write
}
# Team admins, maintainers, observer_plus and observers can read scripts for their teams.
allow {
object.type == "host_script_result"
not is_null(object.team_id)
team_role(subject, object.team_id) == [admin, maintainer, observer_plus, observer][_]
action == read
}

View File

@ -1814,6 +1814,96 @@ func TestAuthorizeMDMAppleCommand(t *testing.T) {
}) })
} }
func TestAuthorizeHostScriptResult(t *testing.T) {
t.Parallel()
globalScript := &fleet.HostScriptResult{}
team1Script := &fleet.HostScriptResult{
TeamID: ptr.Uint(1),
}
runTestCases(t, []authTestCase{
{user: test.UserNoRoles, object: globalScript, action: write, allow: false},
{user: test.UserNoRoles, object: globalScript, action: read, allow: false},
{user: test.UserNoRoles, object: team1Script, action: write, allow: false},
{user: test.UserNoRoles, object: team1Script, action: read, allow: false},
{user: test.UserAdmin, object: globalScript, action: write, allow: true},
{user: test.UserAdmin, object: globalScript, action: read, allow: true},
{user: test.UserAdmin, object: team1Script, action: write, allow: true},
{user: test.UserAdmin, object: team1Script, action: read, allow: true},
{user: test.UserMaintainer, object: globalScript, action: write, allow: true},
{user: test.UserMaintainer, object: globalScript, action: read, allow: true},
{user: test.UserMaintainer, object: team1Script, action: write, allow: true},
{user: test.UserMaintainer, object: team1Script, action: read, allow: true},
{user: test.UserObserver, object: globalScript, action: write, allow: false},
{user: test.UserObserver, object: globalScript, action: read, allow: true},
{user: test.UserObserver, object: team1Script, action: write, allow: false},
{user: test.UserObserver, object: team1Script, action: read, allow: true},
{user: test.UserObserverPlus, object: globalScript, action: write, allow: false},
{user: test.UserObserverPlus, object: globalScript, action: read, allow: true},
{user: test.UserObserverPlus, object: team1Script, action: write, allow: false},
{user: test.UserObserverPlus, object: team1Script, action: read, allow: true},
{user: test.UserGitOps, object: globalScript, action: write, allow: false},
{user: test.UserGitOps, object: globalScript, action: read, allow: false},
{user: test.UserGitOps, object: team1Script, action: write, allow: false},
{user: test.UserGitOps, object: team1Script, action: read, allow: false},
{user: test.UserTeamAdminTeam1, object: globalScript, action: write, allow: false},
{user: test.UserTeamAdminTeam1, object: globalScript, action: read, allow: false},
{user: test.UserTeamAdminTeam1, object: team1Script, action: write, allow: true},
{user: test.UserTeamAdminTeam1, object: team1Script, action: read, allow: true},
{user: test.UserTeamAdminTeam2, object: globalScript, action: write, allow: false},
{user: test.UserTeamAdminTeam2, object: globalScript, action: read, allow: false},
{user: test.UserTeamAdminTeam2, object: team1Script, action: write, allow: false},
{user: test.UserTeamAdminTeam2, object: team1Script, action: read, allow: false},
{user: test.UserTeamMaintainerTeam1, object: globalScript, action: write, allow: false},
{user: test.UserTeamMaintainerTeam1, object: globalScript, action: read, allow: false},
{user: test.UserTeamMaintainerTeam1, object: team1Script, action: write, allow: true},
{user: test.UserTeamMaintainerTeam1, object: team1Script, action: read, allow: true},
{user: test.UserTeamMaintainerTeam2, object: globalScript, action: write, allow: false},
{user: test.UserTeamMaintainerTeam2, object: globalScript, action: read, allow: false},
{user: test.UserTeamMaintainerTeam2, object: team1Script, action: write, allow: false},
{user: test.UserTeamMaintainerTeam2, object: team1Script, action: read, allow: false},
{user: test.UserTeamObserverTeam1, object: globalScript, action: write, allow: false},
{user: test.UserTeamObserverTeam1, object: globalScript, action: read, allow: false},
{user: test.UserTeamObserverTeam1, object: team1Script, action: write, allow: false},
{user: test.UserTeamObserverTeam1, object: team1Script, action: read, allow: true},
{user: test.UserTeamObserverTeam2, object: globalScript, action: write, allow: false},
{user: test.UserTeamObserverTeam2, object: globalScript, action: read, allow: false},
{user: test.UserTeamObserverTeam2, object: team1Script, action: write, allow: false},
{user: test.UserTeamObserverTeam2, object: team1Script, action: read, allow: false},
{user: test.UserTeamObserverPlusTeam1, object: globalScript, action: write, allow: false},
{user: test.UserTeamObserverPlusTeam1, object: globalScript, action: read, allow: false},
{user: test.UserTeamObserverPlusTeam1, object: team1Script, action: write, allow: false},
{user: test.UserTeamObserverPlusTeam1, object: team1Script, action: read, allow: true},
{user: test.UserTeamObserverPlusTeam2, object: globalScript, action: write, allow: false},
{user: test.UserTeamObserverPlusTeam2, object: globalScript, action: read, allow: false},
{user: test.UserTeamObserverPlusTeam2, object: team1Script, action: write, allow: false},
{user: test.UserTeamObserverPlusTeam2, object: team1Script, action: read, allow: false},
{user: test.UserTeamGitOpsTeam1, object: globalScript, action: write, allow: false},
{user: test.UserTeamGitOpsTeam1, object: globalScript, action: read, allow: false},
{user: test.UserTeamGitOpsTeam1, object: team1Script, action: write, allow: false},
{user: test.UserTeamGitOpsTeam1, object: team1Script, action: read, allow: false},
{user: test.UserTeamGitOpsTeam2, object: globalScript, action: write, allow: false},
{user: test.UserTeamGitOpsTeam2, object: globalScript, action: read, allow: false},
{user: test.UserTeamGitOpsTeam2, object: team1Script, action: write, allow: false},
{user: test.UserTeamGitOpsTeam2, object: team1Script, action: read, allow: false},
})
}
func TestJSONToInterfaceUser(t *testing.T) { func TestJSONToInterfaceUser(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -18,6 +18,7 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/log" "github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level" "github.com/go-kit/kit/log/level"
"github.com/google/uuid"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )
@ -464,6 +465,7 @@ var hostRefs = []string{
"host_disk_encryption_keys", "host_disk_encryption_keys",
"host_software_installed_paths", "host_software_installed_paths",
"host_dep_assignments", "host_dep_assignments",
"host_script_results",
} }
// those host refs cannot be deleted using the host.id like the hostRefs above, // those host refs cannot be deleted using the host.id like the hostRefs above,
@ -4187,3 +4189,108 @@ func (ds *Datastore) GetMatchingHostSerials(ctx context.Context, serials []strin
return result, nil return result, nil
} }
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
const (
insStmt = `INSERT INTO host_script_results (host_id, execution_id, script_contents, output) VALUES (?, ?, ?, '')`
getStmt = `SELECT id, host_id, execution_id, script_contents FROM host_script_results WHERE id = ?`
)
execID := uuid.New().String()
result, err := ds.writer(ctx).ExecContext(ctx, insStmt,
request.HostID,
execID,
request.ScriptContents,
)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "new host script execution request")
}
var script fleet.HostScriptResult
id, _ := result.LastInsertId()
if err := ds.writer(ctx).GetContext(ctx, &script, getStmt, id); err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting the created host script result to return")
}
return &script, nil
}
func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload) error {
const updStmt = `
UPDATE host_script_results SET
output = ?,
runtime = ?,
exit_code = ?
WHERE
host_id = ? AND
execution_id = ?`
const maxOutputRuneLen = 10000
output := result.Output
if len(output) > utf8.UTFMax*maxOutputRuneLen {
// truncate the bytes as we know the output is too long, no point
// converting more bytes than needed to runes.
output = output[len(output)-utf8.UTFMax*maxOutputRuneLen:]
}
if utf8.RuneCountInString(output) > maxOutputRuneLen {
outputRunes := []rune(output)
output = string(outputRunes[len(outputRunes)-maxOutputRuneLen:])
}
if _, err := ds.writer(ctx).ExecContext(ctx, updStmt,
output,
result.Runtime,
result.ExitCode,
result.HostID,
result.ExecutionID,
); err != nil {
return ctxerr.Wrap(ctx, err, "update host script result")
}
return nil
}
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, ignoreOlder time.Duration) ([]*fleet.HostScriptResult, error) {
const listStmt = `
SELECT
id,
host_id,
execution_id,
script_contents
FROM
host_script_results
WHERE
host_id = ? AND
exit_code IS NULL AND
created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)`
var results []*fleet.HostScriptResult
seconds := int(ignoreOlder.Seconds())
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID, seconds); err != nil {
return nil, ctxerr.Wrap(ctx, err, "list pending host script results")
}
return results, nil
}
func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
const getStmt = `
SELECT
id,
host_id,
execution_id,
script_contents,
output,
runtime,
exit_code
FROM
host_script_results
WHERE
execution_id = ?`
var result fleet.HostScriptResult
if err := sqlx.GetContext(ctx, ds.reader(ctx), &result, getStmt, execID); err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID))
}
return nil, ctxerr.Wrap(ctx, err, "get host script result")
}
return &result, nil
}

View File

@ -151,6 +151,7 @@ func TestHosts(t *testing.T) {
{"ListHostsLiteByUUIDs", testHostsListHostsLiteByUUIDs}, {"ListHostsLiteByUUIDs", testHostsListHostsLiteByUUIDs},
{"GetMatchingHostSerials", testGetMatchingHostSerials}, {"GetMatchingHostSerials", testGetMatchingHostSerials},
{"ListHostsLiteByIDs", testHostsListHostsLiteByIDs}, {"ListHostsLiteByIDs", testHostsListHostsLiteByIDs},
{"HostScriptResult", testHostScriptResult},
} }
for _, c := range cases { for _, c := range cases {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
@ -5775,6 +5776,8 @@ func testHostsDeleteHosts(t *testing.T, ds *Datastore) {
require.NoError(t, err) require.NoError(t, err)
err = ds.RecordHostBootstrapPackage(context.Background(), "command-uuid", host.UUID) err = ds.RecordHostBootstrapPackage(context.Background(), "command-uuid", host.UUID)
require.NoError(t, err) require.NoError(t, err)
_, err = ds.NewHostScriptExecutionRequest(context.Background(), &fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "foo"})
require.NoError(t, err)
// Check there's an entry for the host in all the associated tables. // Check there's an entry for the host in all the associated tables.
for _, hostRef := range hostRefs { for _, hostRef := range hostRefs {
@ -7263,3 +7266,118 @@ func testHostsListHostsLiteByIDs(t *testing.T, ds *Datastore) {
}) })
} }
} }
func testHostScriptResult(t *testing.T, ds *Datastore) {
ctx := context.Background()
// no script saved yet
pending, err := ds.ListPendingHostScriptExecutions(ctx, 1, time.Second)
require.NoError(t, err)
require.Empty(t, pending)
_, err = ds.GetHostScriptExecutionResult(ctx, "abc")
require.Error(t, err)
var nfe *notFoundError
require.ErrorAs(t, err, &nfe)
// create a createdScript execution request
createdScript, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: 1,
ScriptContents: "echo",
})
require.NoError(t, err)
require.NotZero(t, createdScript.ID)
require.NotEmpty(t, createdScript.ExecutionID)
require.Equal(t, uint(1), createdScript.HostID)
require.NotEmpty(t, createdScript.ExecutionID)
require.Equal(t, "echo", createdScript.ScriptContents)
require.False(t, createdScript.ExitCode.Valid)
require.Empty(t, createdScript.Output)
// the script execution is now listed as pending for this host
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, 10*time.Second)
require.NoError(t, err)
require.Len(t, pending, 1)
require.Equal(t, createdScript.ID, pending[0].ID)
// waiting for a second and an ignore of 0s ignores this script
time.Sleep(time.Second)
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, 0)
require.NoError(t, err)
require.Empty(t, pending)
// record a result for this execution
err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: 1,
ExecutionID: createdScript.ExecutionID,
Output: "foo",
Runtime: 2,
ExitCode: 0,
})
require.NoError(t, err)
// it is not pending anymore
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1, 10*time.Second)
require.NoError(t, err)
require.Empty(t, pending)
// the script result can be retrieved
script, err := ds.GetHostScriptExecutionResult(ctx, createdScript.ExecutionID)
require.NoError(t, err)
expectScript := *createdScript
expectScript.Output = "foo"
expectScript.Runtime = 2
expectScript.ExitCode = sql.NullInt64{Int64: 0, Valid: true}
require.Equal(t, &expectScript, script)
// create another script execution request
createdScript, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: 1,
ScriptContents: "echo2",
})
require.NoError(t, err)
require.NotZero(t, createdScript.ID)
require.NotEmpty(t, createdScript.ExecutionID)
// the script result can be retrieved even if it has no result yet
script, err = ds.GetHostScriptExecutionResult(ctx, createdScript.ExecutionID)
require.NoError(t, err)
require.Equal(t, createdScript, script)
// record a result for this execution, with an output that is too large
largeOutput := strings.Repeat("a", 1000) +
strings.Repeat("b", 1000) +
strings.Repeat("c", 1000) +
strings.Repeat("d", 1000) +
strings.Repeat("e", 1000) +
strings.Repeat("f", 1000) +
strings.Repeat("g", 1000) +
strings.Repeat("h", 1000) +
strings.Repeat("i", 1000) +
strings.Repeat("j", 1000) +
strings.Repeat("k", 1000)
expectedOutput := strings.Repeat("b", 1000) +
strings.Repeat("c", 1000) +
strings.Repeat("d", 1000) +
strings.Repeat("e", 1000) +
strings.Repeat("f", 1000) +
strings.Repeat("g", 1000) +
strings.Repeat("h", 1000) +
strings.Repeat("i", 1000) +
strings.Repeat("j", 1000) +
strings.Repeat("k", 1000)
err = ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: 1,
ExecutionID: createdScript.ExecutionID,
Output: largeOutput,
Runtime: 10,
ExitCode: 1,
})
require.NoError(t, err)
// the script result can be retrieved
script, err = ds.GetHostScriptExecutionResult(ctx, createdScript.ExecutionID)
require.NoError(t, err)
require.Equal(t, expectedOutput, script.Output)
}

View File

@ -0,0 +1,67 @@
package tables
import (
"database/sql"
"fmt"
)
func init() {
MigrationClient.AddMigration(Up_20230814150442, Down_20230814150442)
}
func Up_20230814150442(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TABLE host_script_results (
id INT(10) UNSIGNED NOT NULL AUTO_INCREMENT,
host_id INT(10) UNSIGNED NOT NULL,
-- execution_id is a unique identifier (e.g. UUID) generated for each
-- execution of a script.
execution_id VARCHAR(255) NOT NULL,
-- in the future, we may have a concept of "saved scripts" and in that case
-- the host_script_results may be associated with a script_id instead of
-- the actual script contents. If that's the case, it may be best to allow
-- this field to be NULL (if a saved script is used) but for now we don't
-- support this so I'm making it NOT NULL.
script_contents TEXT NOT NULL,
-- output is the combination of stdout and stderr from the script execution.
output TEXT NOT NULL,
-- runtime is the execution time of the script in seconds, rounded.
runtime INT(10) UNSIGNED NOT NULL DEFAULT 0,
-- the exit code of the script execution, large enough to not assume too
-- much about the possible range (e.g. https://stackoverflow.com/a/328423/1094941)
-- It can be NULL to represent that the script results have not been received
-- yet, and -1 if the script executed but was terminated abruptly (e.g. due to
-- a signal/timeout, same as how Go reports this: https://pkg.go.dev/os#ProcessState.ExitCode).
exit_code INT(10) NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (id),
-- this index can be used to lookup results for a specific
-- execution (execution ids, e.g. when updating the row for results)
UNIQUE KEY idx_host_script_results_execution_id (execution_id),
-- this index can be used to lookup results for a host, to check if a host is currently
-- executing a script (by host_id and with exit_code = NULL), and an created_at condition
-- can be added to dismiss a pending execution that's been running for too long (e.g. host
-- was offline and never sent results, we should eventually start accepting a new
-- script execution).
KEY idx_host_script_results_host_exit_created (host_id, exit_code, created_at)
)`)
if err != nil {
return fmt.Errorf("failed to create host_script_results table: %w", err)
}
return nil
}
func Down_20230814150442(tx *sql.Tx) error {
return nil
}

View File

@ -0,0 +1,88 @@
package tables
import (
"database/sql"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestUp_20230814150442(t *testing.T) {
db := applyUpToPrev(t)
// Apply current migration.
applyNext(t, db)
// NOTE: output field must be provided explicitly (even if empty), because TEXT fields
// cannot have a default value.
insertStmt := `INSERT INTO host_script_results (
host_id, execution_id, script_contents, output
) VALUES (?, ?, ?, '')`
hostID := 123
execID := uuid.New().String()
scriptContents := "echo 'hello world'"
res, err := db.Exec(insertStmt, hostID, execID, scriptContents)
require.NoError(t, err)
id, _ := res.LastInsertId()
require.Greater(t, id, int64(0))
type hostScriptResult struct {
ID int `db:"id"`
HostID int `db:"host_id"`
ExecutionID string `db:"execution_id"`
ScriptContents string `db:"script_contents"`
Output string `db:"output"`
Runtime int `db:"runtime"`
ExitCode sql.NullInt64 `db:"exit_code"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
// load the host we just created
var scriptResult hostScriptResult
selectStmt := `SELECT id, host_id, execution_id, script_contents, output, runtime, exit_code, created_at, updated_at
FROM host_script_results
WHERE id = ?`
err = db.Get(&scriptResult, selectStmt, id)
require.NoError(t, err)
require.Equal(t, int(id), scriptResult.ID)
require.Equal(t, hostID, scriptResult.HostID)
require.Equal(t, execID, scriptResult.ExecutionID)
require.Equal(t, scriptContents, scriptResult.ScriptContents)
require.Empty(t, scriptResult.Output)
require.Zero(t, scriptResult.Runtime)
require.False(t, scriptResult.ExitCode.Valid)
require.NotZero(t, scriptResult.CreatedAt)
require.NotZero(t, scriptResult.UpdatedAt)
// check pending executions for a given host
var countPending int
countPendingStmt := `SELECT COUNT(*)
FROM host_script_results
WHERE host_id = ? AND exit_code IS NULL`
err = db.Get(&countPending, countPendingStmt, hostID)
require.NoError(t, err)
require.Equal(t, 1, countPending)
// update the host we just created
output := `hello world`
runtime := 10
exitCode := int64(0)
updateStmt := `UPDATE host_script_results SET output = ?, runtime = ?, exit_code = ? WHERE host_id = ? AND execution_id = ?`
_, err = db.Exec(updateStmt, output, runtime, exitCode, hostID, execID)
require.NoError(t, err)
// reload the updated host result
err = db.Get(&scriptResult, selectStmt, id)
require.NoError(t, err)
require.Equal(t, output, scriptResult.Output)
require.Equal(t, runtime, scriptResult.Runtime)
require.True(t, scriptResult.ExitCode.Valid)
require.Equal(t, exitCode, scriptResult.ExitCode.Int64)
}

File diff suppressed because one or more lines are too long

View File

@ -1009,6 +1009,25 @@ type Datastore interface {
MDMWindowsInsertEnrolledDevice(ctx context.Context, device *MDMWindowsEnrolledDevice) error MDMWindowsInsertEnrolledDevice(ctx context.Context, device *MDMWindowsEnrolledDevice) error
// MDMWindowsDeleteEnrolledDevice deletes a give MDMWindowsEnrolledDevice entry from the database using the device id. // MDMWindowsDeleteEnrolledDevice deletes a give MDMWindowsEnrolledDevice entry from the database using the device id.
MDMWindowsDeleteEnrolledDevice(ctx context.Context, mdmDeviceID string) error MDMWindowsDeleteEnrolledDevice(ctx context.Context, mdmDeviceID string) error
///////////////////////////////////////////////////////////////////////////////
// Host Script Results
// NewHostScriptExecutionRequest creates a new host script result entry with
// just the script to run information (result is not yet available).
NewHostScriptExecutionRequest(ctx context.Context, request *HostScriptRequestPayload) (*HostScriptResult, error)
// SetHostScriptExecutionResult stores the result of a host script execution.
SetHostScriptExecutionResult(ctx context.Context, result *HostScriptResultPayload) error
// GetHostScriptExecutionResult returns the result of a host script
// execution. It returns the host script results even if no results have been
// received, it is the caller's responsibility to check if that was the case
// (with ExitCode being null).
GetHostScriptExecutionResult(ctx context.Context, execID string) (*HostScriptResult, error)
// ListPendingHostScriptExecutions returns all the pending host script
// executions, which are those that have yet to record a result. Entries
// older than the ignoreOlder duration are ignored, considered too old to be
// pending.
ListPendingHostScriptExecutions(ctx context.Context, hostID uint, ignoreOlder time.Duration) ([]*HostScriptResult, error)
} }
const ( const (

View File

@ -311,31 +311,43 @@ func (e *MDMNotConfiguredError) Error() string {
return "MDM features aren't turned on in Fleet. For more information about setting up MDM, please visit https://fleetdm.com/docs/using-fleet/mobile-device-management" return "MDM features aren't turned on in Fleet. For more information about setting up MDM, please visit https://fleetdm.com/docs/using-fleet/mobile-device-management"
} }
// BadGatewayError is an error type that generates a 502 status code. // GatewayError is an error type that generates a 502 or 504 status code.
type BadGatewayError struct { type GatewayError struct {
Message string Message string
err error err error
code int
ErrorWithUUID ErrorWithUUID
} }
// NewBadGatewayError returns a MDMBadGatewayError with the message and // NewBadGatewayError returns a GatewayError with the message and
// error specified. // error specified and that returns a 502 status code.
func NewBadGatewayError(message string, err error) *BadGatewayError { func NewBadGatewayError(message string, err error) *GatewayError {
return &BadGatewayError{ return &GatewayError{
Message: message, Message: message,
err: err, err: err,
code: http.StatusBadGateway,
}
}
// NewGatewayTimeoutError returns a GatewayError with the message and
// error specified and that returns a 504 status code.
func NewGatewayTimeoutError(message string, err error) *GatewayError {
return &GatewayError{
Message: message,
err: err,
code: http.StatusGatewayTimeout,
} }
} }
// StatusCode implements the kithttp.StatusCoder interface so we can customize the // StatusCode implements the kithttp.StatusCoder interface so we can customize the
// HTTP status code of the response returning this error. // HTTP status code of the response returning this error.
func (e *BadGatewayError) StatusCode() int { func (e *GatewayError) StatusCode() int {
return http.StatusBadGateway return e.code
} }
// Error returns the error message. // Error returns the error message.
func (e *BadGatewayError) Error() string { func (e *GatewayError) Error() string {
msg := e.Message msg := e.Message
if e.err != nil { if e.err != nil {
msg += ": " + e.err.Error() msg += ": " + e.err.Error()

View File

@ -2,6 +2,7 @@ package fleet
import ( import (
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -1071,3 +1072,47 @@ type HostMacOSProfile struct {
// InstallDate is the date the profile was installed on the host as reported by the host's clock. // InstallDate is the date the profile was installed on the host as reported by the host's clock.
InstallDate time.Time `json:"install_date" db:"install_date"` InstallDate time.Time `json:"install_date" db:"install_date"`
} }
type HostScriptRequestPayload struct {
HostID uint `json:"host_id"`
ScriptContents string `json:"script_contents"`
}
type HostScriptResultPayload struct {
HostID uint `json:"host_id"`
ExecutionID string `json:"execution_id"`
Output string `json:"output"`
Runtime int `json:"runtime"`
ExitCode int `json:"exit_code"`
}
// HostScriptResult represents a script result that was requested to execute on
// a specific host. If no result was received yet for a script, the ExitCode
// field is null and the output is empty.
type HostScriptResult struct {
// ID is the unique row identifier of the host script result.
ID uint `json:"-" db:"id"`
// HostID is the host on which the script was executed.
HostID uint `json:"host_id" db:"host_id"`
// ExecutionID is a unique identifier for a single execution of the script.
ExecutionID string `json:"execution_id" db:"execution_id"`
// ScriptContents is the content of the script to execute.
ScriptContents string `json:"script_contents" db:"script_contents"`
// Output is the combined stdout/stderr output of the script. It is empty
// if no result was received yet.
Output string `json:"output" db:"output"`
// Runtime is the running time of the script in seconds, rounded.
Runtime int `json:"runtime" db:"runtime"`
// ExitCode is null if script execution result was never received from the
// host. It is -1 if it was received but the script did not terminate
// normally (same as how Go handles this: https://pkg.go.dev/os#ProcessState.ExitCode)
ExitCode sql.NullInt64 `json:"exit_code" db:"exit_code"`
// TeamID is only used for authorization, it must be set to the team id of
// the host when checking authorization and is otherwise not set.
TeamID *uint `json:"team_id" db:"-"`
}
func (hsr HostScriptResult) AuthzType() string {
return "host_script_result"
}

View File

@ -787,4 +787,12 @@ type Service interface {
// GetMDMWindowsTOSContent returns TOS content // GetMDMWindowsTOSContent returns TOS content
GetMDMWindowsTOSContent(ctx context.Context, redirectUri string, reqID string) (string, error) GetMDMWindowsTOSContent(ctx context.Context, redirectUri string, reqID string) (string, error)
///////////////////////////////////////////////////////////////////////////////
// Host Script Execution
// RunHostScript executes a script on a host and optionally waits for the
// result if waitForResult is > 0. If it times out waiting for a result, it
// fails with a 504 Gateway Timeout error.
RunHostScript(ctx context.Context, request *HostScriptRequestPayload, waitForResult time.Duration) (*HostScriptResult, error)
} }

View File

@ -664,6 +664,14 @@ type MDMWindowsInsertEnrolledDeviceFunc func(ctx context.Context, device *fleet.
type MDMWindowsDeleteEnrolledDeviceFunc func(ctx context.Context, mdmDeviceID string) error type MDMWindowsDeleteEnrolledDeviceFunc func(ctx context.Context, mdmDeviceID string) error
type NewHostScriptExecutionRequestFunc func(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error)
type SetHostScriptExecutionResultFunc func(ctx context.Context, result *fleet.HostScriptResultPayload) error
type GetHostScriptExecutionResultFunc func(ctx context.Context, execID string) (*fleet.HostScriptResult, error)
type ListPendingHostScriptExecutionsFunc func(ctx context.Context, hostID uint, ignoreOlder time.Duration) ([]*fleet.HostScriptResult, error)
type DataStore struct { type DataStore struct {
HealthCheckFunc HealthCheckFunc HealthCheckFunc HealthCheckFunc
HealthCheckFuncInvoked bool HealthCheckFuncInvoked bool
@ -1634,6 +1642,18 @@ type DataStore struct {
MDMWindowsDeleteEnrolledDeviceFunc MDMWindowsDeleteEnrolledDeviceFunc MDMWindowsDeleteEnrolledDeviceFunc MDMWindowsDeleteEnrolledDeviceFunc
MDMWindowsDeleteEnrolledDeviceFuncInvoked bool MDMWindowsDeleteEnrolledDeviceFuncInvoked bool
NewHostScriptExecutionRequestFunc NewHostScriptExecutionRequestFunc
NewHostScriptExecutionRequestFuncInvoked bool
SetHostScriptExecutionResultFunc SetHostScriptExecutionResultFunc
SetHostScriptExecutionResultFuncInvoked bool
GetHostScriptExecutionResultFunc GetHostScriptExecutionResultFunc
GetHostScriptExecutionResultFuncInvoked bool
ListPendingHostScriptExecutionsFunc ListPendingHostScriptExecutionsFunc
ListPendingHostScriptExecutionsFuncInvoked bool
mu sync.Mutex mu sync.Mutex
} }
@ -3897,3 +3917,31 @@ func (s *DataStore) MDMWindowsDeleteEnrolledDevice(ctx context.Context, mdmDevic
s.mu.Unlock() s.mu.Unlock()
return s.MDMWindowsDeleteEnrolledDeviceFunc(ctx, mdmDeviceID) return s.MDMWindowsDeleteEnrolledDeviceFunc(ctx, mdmDeviceID)
} }
func (s *DataStore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
s.mu.Lock()
s.NewHostScriptExecutionRequestFuncInvoked = true
s.mu.Unlock()
return s.NewHostScriptExecutionRequestFunc(ctx, request)
}
func (s *DataStore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload) error {
s.mu.Lock()
s.SetHostScriptExecutionResultFuncInvoked = true
s.mu.Unlock()
return s.SetHostScriptExecutionResultFunc(ctx, result)
}
func (s *DataStore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
s.mu.Lock()
s.GetHostScriptExecutionResultFuncInvoked = true
s.mu.Unlock()
return s.GetHostScriptExecutionResultFunc(ctx, execID)
}
func (s *DataStore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, ignoreOlder time.Duration) ([]*fleet.HostScriptResult, error) {
s.mu.Lock()
s.ListPendingHostScriptExecutionsFuncInvoked = true
s.mu.Unlock()
return s.ListPendingHostScriptExecutionsFunc(ctx, hostID, ignoreOlder)
}

View File

@ -439,6 +439,9 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ue.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil) ue.GET("/api/_version_/fleet/status/result_store", statusResultStoreEndpoint, nil)
ue.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil) ue.GET("/api/_version_/fleet/status/live_query", statusLiveQueryEndpoint, nil)
ue.POST("/api/_version_/fleet/scripts/run", runScriptEndpoint, runScriptRequest{})
ue.POST("/api/_version_/fleet/scripts/run/sync", runScriptSyncEndpoint, runScriptRequest{})
// Only Fleet MDM specific endpoints should be within the root /mdm/ path. // Only Fleet MDM specific endpoints should be within the root /mdm/ path.
// NOTE: remember to update // NOTE: remember to update
// `service.mdmAppleConfigurationRequiredEndpoints` when you add an // `service.mdmAppleConfigurationRequiredEndpoints` when you add an

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/csv" "encoding/csv"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -1588,3 +1589,99 @@ func (svc *Service) HostEncryptionKey(ctx context.Context, id uint) (*fleet.Host
return key, nil return key, nil
} }
////////////////////////////////////////////////////////////////////////////////
// Run Script on a Host (async)
////////////////////////////////////////////////////////////////////////////////
type runScriptRequest struct {
HostID uint `json:"host_id"`
ScriptContents string `json:"script_contents"`
}
type runScriptResponse struct {
Err error `json:"error,omitempty"`
HostID uint `json:"host_id,omitempty"`
ExecutionID string `json:"execution_id,omitempty"`
}
func (r runScriptResponse) error() error { return r.Err }
func (r runScriptResponse) Status() int { return http.StatusAccepted }
func runScriptEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*runScriptRequest)
var noWait time.Duration
result, err := svc.RunHostScript(ctx, &fleet.HostScriptRequestPayload{
HostID: req.HostID,
ScriptContents: req.ScriptContents,
}, noWait)
if err != nil {
return runScriptResponse{Err: err}, nil
}
return runScriptResponse{HostID: result.HostID, ExecutionID: result.ExecutionID}, nil
}
////////////////////////////////////////////////////////////////////////////////
// Run Script on a Host (sync)
////////////////////////////////////////////////////////////////////////////////
type runScriptSyncResponse struct {
Err error `json:"error,omitempty"`
*fleet.HostScriptResult
// only set if the error was a timeout waiting for a result
ErrorMessage string `json:"error_message,omitempty"`
}
func (r runScriptSyncResponse) error() error { return r.Err }
func (r runScriptSyncResponse) Status() int {
if r.ErrorMessage != "" {
return http.StatusGatewayTimeout
}
return http.StatusOK
}
// this is to be used only by tests, to be able to use a shorter timeout.
var testRunScriptWaitForResult time.Duration
func runScriptSyncEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
waitForResult := time.Minute
if testRunScriptWaitForResult != 0 {
waitForResult = testRunScriptWaitForResult
}
req := request.(*runScriptRequest)
result, err := svc.RunHostScript(ctx, &fleet.HostScriptRequestPayload{
HostID: req.HostID,
ScriptContents: req.ScriptContents,
}, waitForResult)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
err = fleet.NewGatewayTimeoutError("script execution timed out waiting for a result", err)
// it should still return the execution id and host id in this situation,
// so the user knows what script request to look at in the UI. We cannot
// return an error (field Err) in this case, as the errorer interface's
// rendering logic would take over and only render the error part of the
// response struct. This is why we use the distinct ErrorMessage field to
// add the error message and status code to the response, along with the
// script request.
return runScriptSyncResponse{
HostScriptResult: result,
ErrorMessage: err.Error(),
}, nil
}
return runScriptSyncResponse{Err: err}, nil
}
return runScriptSyncResponse{
HostScriptResult: result,
}, nil
}
func (svc *Service) RunHostScript(ctx context.Context, request *fleet.HostScriptRequestPayload, waitForResult time.Duration) (*fleet.HostScriptResult, error) {
// skipauth: No authorization check needed due to implementation returning
// only license error.
svc.authz.SkipAuthorization(ctx)
return nil, fleet.ErrMissingLicense
}

View File

@ -1080,3 +1080,155 @@ func TestHostMDMProfileDetail(t *testing.T) {
}) })
} }
} }
func TestHostRunScript(t *testing.T) {
ds := new(mock.Store)
license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)}
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true})
// use a custom implementation of checkAuthErr as the service call will fail
// with a not found error for unknown host in case of authorization success,
// and the package-wide checkAuthErr requires no error.
checkAuthErr := func(t *testing.T, shouldFail bool, err error) {
if shouldFail {
require.Error(t, err)
require.Equal(t, (&authz.Forbidden{}).Error(), err.Error())
} else if err != nil {
require.NotEqual(t, (&authz.Forbidden{}).Error(), err.Error())
}
}
teamHost := &fleet.Host{ID: 1, Hostname: "host-team", TeamID: ptr.Uint(1)}
noTeamHost := &fleet.Host{ID: 2, Hostname: "host-no-team", TeamID: nil}
nonExistingHost := &fleet.Host{ID: 3, Hostname: "no-such-host", TeamID: nil}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
ds.HostLiteFunc = func(ctx context.Context, hostID uint) (*fleet.Host, error) {
if hostID == 1 {
return teamHost, nil
}
if hostID == 2 {
return noTeamHost, nil
}
return nil, newNotFoundError()
}
ds.NewHostScriptExecutionRequestFunc = func(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
return &fleet.HostScriptResult{HostID: request.HostID, ScriptContents: request.ScriptContents, ExecutionID: "abc"}, nil
}
ds.ListPendingHostScriptExecutionsFunc = func(ctx context.Context, hostID uint, ignoreOlder time.Duration) ([]*fleet.HostScriptResult, error) {
return nil, nil
}
testCases := []struct {
name string
user *fleet.User
shouldFailTeamWrite bool
shouldFailGlobalWrite bool
}{
{
name: "global admin",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
shouldFailTeamWrite: false,
shouldFailGlobalWrite: false,
},
{
name: "global maintainer",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
shouldFailTeamWrite: false,
shouldFailGlobalWrite: false,
},
{
name: "global observer",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "global observer+",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserverPlus)},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "global gitops",
user: &fleet.User{GlobalRole: ptr.String(fleet.RoleGitOps)},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team admin, belongs to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
shouldFailTeamWrite: false,
shouldFailGlobalWrite: true,
},
{
name: "team maintainer, belongs to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
shouldFailTeamWrite: false,
shouldFailGlobalWrite: true,
},
{
name: "team observer, belongs to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team observer+, belongs to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserverPlus}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team gitops, belongs to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleGitOps}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team admin, DOES NOT belong to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleAdmin}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team maintainer, DOES NOT belong to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team observer, DOES NOT belong to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team observer+, DOES NOT belong to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserverPlus}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
{
name: "team gitops, DOES NOT belong to team",
user: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleGitOps}}},
shouldFailTeamWrite: true,
shouldFailGlobalWrite: true,
},
}
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ctx = viewer.NewContext(ctx, viewer.Viewer{User: tt.user})
_, err := svc.RunHostScript(ctx, &fleet.HostScriptRequestPayload{HostID: noTeamHost.ID, ScriptContents: "abc"}, 0)
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
_, err = svc.RunHostScript(ctx, &fleet.HostScriptRequestPayload{HostID: teamHost.ID, ScriptContents: "abc"}, 0)
checkAuthErr(t, tt.shouldFailTeamWrite, err)
// a non-existing host is authorized as for global write (because we can't know what team it belongs to)
_, err = svc.RunHostScript(ctx, &fleet.HostScriptRequestPayload{HostID: nonExistingHost.ID, ScriptContents: "abc"}, 0)
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
})
}
}

View File

@ -4579,6 +4579,13 @@ func (s *integrationTestSuite) TestPremiumEndpointsWithoutLicense() {
// device migrate mdm endpoint returns an error if not premium // device migrate mdm endpoint returns an error if not premium
createHostAndDeviceToken(t, s.ds, "some-token") createHostAndDeviceToken(t, s.ds, "some-token")
s.Do("POST", fmt.Sprintf("/api/v1/fleet/device/%s/migrate_mdm", "some-token"), nil, http.StatusPaymentRequired) s.Do("POST", fmt.Sprintf("/api/v1/fleet/device/%s/migrate_mdm", "some-token"), nil, http.StatusPaymentRequired)
// run a script
var runResp runScriptResponse
s.DoJSON("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: 1}, http.StatusPaymentRequired, &runResp)
// run a script sync
s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: 1}, http.StatusPaymentRequired, &runResp)
} }
// TestGlobalPoliciesBrowsing tests that team users can browse (read) global policies (see #3722). // TestGlobalPoliciesBrowsing tests that team users can browse (read) global policies (see #3722).

View File

@ -3651,3 +3651,123 @@ func (s *integrationEnterpriseTestSuite) TestDesktopEndpointWithInvalidPolicy()
require.NoError(t, desktopRes.Err) require.NoError(t, desktopRes.Err)
require.Equal(t, uint(0), *desktopRes.FailingPolicies) require.Equal(t, uint(0), *desktopRes.FailingPolicies)
} }
func (s *integrationEnterpriseTestSuite) TestRunHostScript() {
t := s.T()
testRunScriptWaitForResult = 2 * time.Second
defer func() { testRunScriptWaitForResult = 0 }()
ctx := context.Background()
host := createOrbitEnrolledHost(t, "linux", "", s.ds)
// attempt to run a script on a non-existing host
var runResp runScriptResponse
s.DoJSON("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID + 100, ScriptContents: "echo"}, http.StatusNotFound, &runResp)
// attempt to run an empty script
res := s.Do("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: ""}, http.StatusUnprocessableEntity)
errMsg := extractServerErrorText(res.Body)
require.Contains(t, errMsg, "a script to execute is required")
// attempt to run an overly long script
res = s.Do("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: strings.Repeat("a", 10001)}, http.StatusUnprocessableEntity)
errMsg = extractServerErrorText(res.Body)
require.Contains(t, errMsg, "script is too long")
// create a valid script execution request
s.DoJSON("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusAccepted, &runResp)
require.Equal(t, host.ID, runResp.HostID)
require.NotEmpty(t, runResp.ExecutionID)
result, err := s.ds.GetHostScriptExecutionResult(ctx, runResp.ExecutionID)
require.NoError(t, err)
require.Equal(t, host.ID, result.HostID)
require.Equal(t, "echo", result.ScriptContents)
require.False(t, result.ExitCode.Valid)
// attempt to run a sync script on a non-existing host
var runSyncResp runScriptSyncResponse
s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID + 100, ScriptContents: "echo"}, http.StatusNotFound, &runSyncResp)
// attempt to sync run an empty script
res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: ""}, http.StatusUnprocessableEntity)
errMsg = extractServerErrorText(res.Body)
require.Contains(t, errMsg, "a script to execute is required")
// attempt to sync run an overly long script
res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: strings.Repeat("a", 10001)}, http.StatusUnprocessableEntity)
errMsg = extractServerErrorText(res.Body)
require.Contains(t, errMsg, "script is too long")
// attempt to create a valid sync script execution request, fails because the
// host has a pending script execution
res = s.Do("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusUnprocessableEntity)
errMsg = extractServerErrorText(res.Body)
require.Contains(t, errMsg, "a script is currently executing on the host")
// simulate a result being returned for the pending script
err = s.ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: host.ID,
ExecutionID: runResp.ExecutionID,
ExitCode: 0,
Output: "ok",
})
require.NoError(t, err)
// create a valid sync script execution request, fails because the
// request will time-out waiting for a result.
runSyncResp = runScriptSyncResponse{}
s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusGatewayTimeout, &runSyncResp)
require.Equal(t, host.ID, runSyncResp.HostID)
require.NotEmpty(t, runSyncResp.ExecutionID)
require.Contains(t, runSyncResp.ErrorMessage, "script execution timed out waiting for a result")
// simulate a result being returned for that pending script
err = s.ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: host.ID,
ExecutionID: runSyncResp.ExecutionID,
ExitCode: 0,
Output: "ok",
})
require.NoError(t, err)
// create a valid sync script execution request, and simulate a result
// arriving before timeout.
testRunScriptWaitForResult = 5 * time.Second
ctx, cancel := context.WithTimeout(ctx, testRunScriptWaitForResult)
defer cancel()
go func() {
for range time.Tick(300 * time.Millisecond) {
pending, err := s.ds.ListPendingHostScriptExecutions(ctx, host.ID, 10*time.Second)
if err != nil {
t.Log(err)
return
}
if len(pending) > 0 {
// ignoring errors in this goroutine, the HTTP request below will fail if this fails
err = s.ds.SetHostScriptExecutionResult(ctx, &fleet.HostScriptResultPayload{
HostID: host.ID,
ExecutionID: pending[0].ExecutionID,
Output: "ok",
Runtime: 1,
ExitCode: 0,
})
if err != nil {
t.Log(err)
}
return
}
}
}()
runSyncResp = runScriptSyncResponse{}
s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusOK, &runSyncResp)
require.Equal(t, host.ID, runSyncResp.HostID)
require.NotEmpty(t, runSyncResp.ExecutionID)
require.Equal(t, "ok", runSyncResp.Output)
require.Equal(t, int64(0), runSyncResp.ExitCode.Int64)
require.Empty(t, runSyncResp.ErrorMessage)
}

View File

@ -1261,6 +1261,7 @@ func createHostThenEnrollMDM(ds fleet.Datastore, fleetServerURL string, t *testi
func (s *integrationMDMTestSuite) TestDEPProfileAssignment() { func (s *integrationMDMTestSuite) TestDEPProfileAssignment() {
t := s.T() t := s.T()
ctx := context.Background() ctx := context.Background()
devices := []godep.Device{ devices := []godep.Device{
{SerialNumber: uuid.New().String(), Model: "MacBook Pro", OS: "osx", OpType: "added"}, {SerialNumber: uuid.New().String(), Model: "MacBook Pro", OS: "osx", OpType: "added"},