Add labels to the fleetd extensions feature (#14008)

#13287

- [X] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- ~[ ] Documented any API changes (docs/Using-Fleet/REST-API.md or
docs/Contributing/API-for-contributors.md)~
- ~[ ] Documented any permissions changes (docs/Using
Fleet/manage-access.md)~
- ~[ ] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)~
- ~[ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for
new osquery data ingestion features.~
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality
  - ~For Orbit and Fleet Desktop changes:~
- ~[ ] Manual QA must be performed in the three main OSs, macOS, Windows
and Linux.~
- ~[ ] Auto-update manual QA, from released version of component to new
version (see [tools/tuf/test](../tools/tuf/test/README.md)).~
This commit is contained in:
Lucas Manuel Rodriguez 2023-09-22 11:09:09 -03:00 committed by GitHub
parent 89a6b538b7
commit 4bdef5dbe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 608 additions and 16 deletions

View File

@ -0,0 +1 @@
* Add `labels` to the fleetd extensions feature to allow deploying extensions to hosts that belong to certain labels.

View File

@ -259,7 +259,7 @@ func (svc *Service) ModifyTeamAgentOptions(ctx context.Context, teamID uint, tea
}
if teamOptions != nil {
if err := fleet.ValidateJSONAgentOptions(teamOptions); err != nil {
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, teamOptions, true); err != nil {
err = fleet.NewUserMessageError(err, http.StatusBadRequest)
if applyOptions.Force && !applyOptions.DryRun {
level.Info(svc.logger).Log("err", err, "msg", "force-apply team agent options with validation errors")
@ -694,7 +694,7 @@ func (svc *Service) ApplyTeamSpecs(ctx context.Context, specs []*fleet.TeamSpec,
}
if len(spec.AgentOptions) > 0 && !bytes.Equal(spec.AgentOptions, jsonNull) {
if err := fleet.ValidateJSONAgentOptions(spec.AgentOptions); err != nil {
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, spec.AgentOptions, true); err != nil {
err = fleet.NewUserMessageError(err, http.StatusBadRequest)
if applyOpts.Force && !applyOpts.DryRun {
level.Info(svc.logger).Log("err", err, "msg", "force-apply team agent options with validation errors")

View File

@ -233,6 +233,8 @@ func (r *ExtensionRunner) DoExtensionConfigUpdate() (bool, error) {
}
}
log.Debug().Str("extensions", string(config.Extensions)).Msg("received extensions configuration")
var extensions fleet.Extensions
err = json.Unmarshal(config.Extensions, &extensions)
if err != nil {

View File

@ -928,3 +928,31 @@ func (ds *Datastore) LabelsSummary(ctx context.Context) ([]*fleet.LabelSummary,
}
return labelsSummary, nil
}
// HostMemberOfAllLabels returns whether the given host is a member of all the provided labels.
// If the labels do not exist, then the host is considered not a member of the provided labels.
// A host will always be a member of an empty label set, so this method returns (true, nil)
// if labelNames is empty.
func (ds *Datastore) HostMemberOfAllLabels(ctx context.Context, hostID uint, labelNames []string) (bool, error) {
if len(labelNames) == 0 {
return true, nil
}
sqlStatement := `
SELECT COUNT(*) = ? FROM labels l
LEFT JOIN (SELECT label_id FROM label_membership WHERE host_id = ?) lm
ON l.id = lm.label_id
WHERE l.name IN (?) AND lm.label_id IS NOT NULL;
`
sql, args, err := sqlx.In(sqlStatement, len(labelNames), hostID, labelNames)
if err != nil {
return false, ctxerr.Wrap(ctx, err, "building query to get label IDs")
}
var ok bool
if err := sqlx.GetContext(ctx, ds.reader(ctx), &ok, sql, args...); err != nil {
return false, ctxerr.Wrap(ctx, err, "get label IDs")
}
return ok, nil
}

View File

@ -65,6 +65,7 @@ func TestLabels(t *testing.T) {
{"LabelsSummary", testLabelsSummary},
{"ListHostsInLabelFailingPolicies", testListHostsInLabelFailingPolicies},
{"ListHostsInLabelDiskEncryptionStatus", testListHostsInLabelDiskEncryptionStatus},
{"HostMemberOfAllLabels", testHostMemberOfAllLabels},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -1142,3 +1143,185 @@ func testListHostsInLabelDiskEncryptionStatus(t *testing.T, ds *Datastore) {
listHostsCheckCount(t, ds, fleet.TeamFilter{User: test.UserAdmin}, fleet.HostListOptions{MacOSSettingsDiskEncryptionFilter: fleet.DiskEncryptionFailed}, 2)
listHostsCheckCount(t, ds, fleet.TeamFilter{User: test.UserAdmin}, fleet.HostListOptions{MacOSSettingsDiskEncryptionFilter: fleet.DiskEncryptionRemovingEnforcement}, 1)
}
func testHostMemberOfAllLabels(t *testing.T, ds *Datastore) {
ctx := context.Background()
//
// Setup test
// - h1 member of 'All hosts', 'Foobar' and 'Zoobar'
// - h2 member of 'All hosts' and 'Foobar'
// - h3 member of 'All hosts' and 'Zoobar'
// - h4 member of 'All hosts'
// - h5 member of no labels
//
allHostsLabel, err := ds.NewLabel(ctx,
&fleet.Label{
Name: "All hosts",
Query: "SELECT 1",
LabelType: fleet.LabelTypeBuiltIn,
LabelMembershipType: fleet.LabelMembershipTypeDynamic,
},
)
require.NoError(t, err)
foobarLabel, err := ds.NewLabel(ctx, &fleet.Label{
Name: "Foobar",
Query: "SELECT 1;",
LabelType: fleet.LabelTypeRegular,
LabelMembershipType: fleet.LabelMembershipTypeDynamic,
})
require.NoError(t, err)
zoobarLabel, err := ds.NewLabel(ctx, &fleet.Label{
Name: "Zoobar",
Query: "SELECT 2;",
LabelType: fleet.LabelTypeRegular,
LabelMembershipType: fleet.LabelMembershipTypeDynamic,
})
require.NoError(t, err)
newHostFunc := func(name string) *fleet.Host {
h, err := ds.NewHost(ctx, &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
OsqueryHostID: ptr.String(name),
NodeKey: ptr.String(name),
UUID: name,
Hostname: "foo.local" + name,
})
require.NoError(t, err)
return h
}
h1 := newHostFunc("h1")
h2 := newHostFunc("h2")
h3 := newHostFunc("h3")
h4 := newHostFunc("h4")
h5 := newHostFunc("h5")
_ = h5
err = ds.RecordLabelQueryExecutions(ctx, h1, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
foobarLabel.ID: ptr.Bool(true),
zoobarLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
err = ds.RecordLabelQueryExecutions(ctx, h2, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
foobarLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
err = ds.RecordLabelQueryExecutions(ctx, h3, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
zoobarLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
err = ds.RecordLabelQueryExecutions(ctx, h4, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
//
// Run tests for HostMemberOfAllLabels
//
for _, tc := range []struct {
name string
hostID uint
labelNames []string
expectedResult bool
}{
{
name: "nonexistent host",
hostID: 999,
labelNames: []string{allHostsLabel.Name},
expectedResult: false,
},
{
name: "h1 does not belong to nonexistent label",
hostID: h1.ID,
labelNames: []string{"Non existent label"},
expectedResult: false,
},
{
name: "h1 does not belong to All hosts + nonexistent label",
hostID: h1.ID,
labelNames: []string{allHostsLabel.Name, "Non existent label"},
expectedResult: false,
},
{
name: "h1 belongs to the given subset of labels",
hostID: h1.ID,
labelNames: []string{allHostsLabel.Name, foobarLabel.Name},
expectedResult: true,
},
{
name: "h1 belongs to all the given labels",
hostID: h1.ID,
labelNames: []string{allHostsLabel.Name, foobarLabel.Name, zoobarLabel.Name},
expectedResult: true,
},
{
name: "h1 member of empty label set",
hostID: h1.ID,
labelNames: []string{},
expectedResult: true,
},
{
name: "h2 belongs to all the given labels",
hostID: h2.ID,
labelNames: []string{allHostsLabel.Name, foobarLabel.Name},
expectedResult: true,
},
{
name: "h2 does not belongs to all the given labels",
hostID: h2.ID,
labelNames: []string{allHostsLabel.Name, foobarLabel.Name, zoobarLabel.Name},
expectedResult: false,
},
{
name: "h2 belongs to the given label",
hostID: h2.ID,
labelNames: []string{foobarLabel.Name},
expectedResult: true,
},
{
name: "h2 does not belong to the given label",
hostID: h2.ID,
labelNames: []string{zoobarLabel.Name},
expectedResult: false,
},
{
name: "h3 belongs to all the given labels",
hostID: h3.ID,
labelNames: []string{allHostsLabel.Name, zoobarLabel.Name},
expectedResult: true,
},
{
name: "h4 belongs to all the given labels",
hostID: h4.ID,
labelNames: []string{allHostsLabel.Name},
expectedResult: true,
},
{
name: "h4 does not belong to the given labels",
hostID: h4.ID,
labelNames: []string{foobarLabel.Name},
expectedResult: false,
},
{
name: "h5 does not belong to the given labels",
hostID: h5.ID,
labelNames: []string{allHostsLabel.Name},
expectedResult: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
v, err := ds.HostMemberOfAllLabels(ctx, tc.hostID, tc.labelNames)
require.NoError(t, err)
require.Equal(t, tc.expectedResult, v)
})
}
}

View File

@ -2,6 +2,7 @@ package fleet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
@ -37,7 +38,7 @@ func (o *AgentOptions) ForPlatform(platform string) json.RawMessage {
// Options payload. It ensures that all fields are known and have valid values.
// The validation always uses the most recent Osquery version that is available
// at the time of the Fleet release.
func ValidateJSONAgentOptions(rawJSON json.RawMessage) error {
func ValidateJSONAgentOptions(ctx context.Context, ds Datastore, rawJSON json.RawMessage, isPremium bool) error {
var opts AgentOptions
if err := JSONStrictDecode(bytes.NewReader(rawJSON), &opts); err != nil {
return err
@ -55,6 +56,7 @@ func ValidateJSONAgentOptions(rawJSON json.RawMessage) error {
return fmt.Errorf("common config: %w", err)
}
}
for platform, platformOpts := range opts.Overrides.Platforms {
if len(platformOpts) > 0 {
if err := validateJSONAgentOptionsSet(platformOpts); err != nil {
@ -62,6 +64,38 @@ func ValidateJSONAgentOptions(rawJSON json.RawMessage) error {
}
}
}
if len(opts.Extensions) > 0 {
if err := validateJSONAgentOptionsExtensions(ctx, ds, opts.Extensions, isPremium); err != nil {
return err
}
}
return nil
}
func validateJSONAgentOptionsExtensions(ctx context.Context, ds Datastore, optsExtensions json.RawMessage, isPremium bool) error {
var extensions map[string]ExtensionInfo
if err := json.Unmarshal(optsExtensions, &extensions); err != nil {
return fmt.Errorf("unmarshal extensions: %w", err)
}
for _, extensionInfo := range extensions {
if !isPremium && len(extensionInfo.Labels) != 0 {
// Setting labels settings in the extensions config is premium only.
return ErrMissingLicense
}
for _, labelName := range extensionInfo.Labels {
switch _, err := ds.GetLabelSpec(ctx, labelName); {
case err == nil:
// OK
case IsNotFound(err):
// Label does not exist, fail the request.
return fmt.Errorf("Label %q does not exist", labelName)
default:
return fmt.Errorf("get label by name: %w", err)
}
}
}
return nil
}

View File

@ -1,6 +1,7 @@
package fleet
import (
"context"
"errors"
"testing"
@ -140,7 +141,7 @@ func TestValidateAgentOptions(t *testing.T) {
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
err := ValidateJSONAgentOptions([]byte(c.in))
err := ValidateJSONAgentOptions(context.Background(), nil, []byte(c.in), true)
t.Logf("%T", errors.Unwrap(err))
if c.wantErr != "" {
require.ErrorContains(t, err, c.wantErr)

View File

@ -230,6 +230,12 @@ type Datastore interface {
// HostIDsByOSID retrieves the IDs of all host for the given OS ID
HostIDsByOSID(ctx context.Context, osID uint, offset int, limit int) ([]uint, error)
// HostMemberOfAllLabels returns whether the given host is a member of all the provided labels.
// If a label name does not exist, then the host is considered not a member of the provided label.
// A host will always be a member of an empty label set, so this method returns (true, nil)
// if labelNames is empty.
HostMemberOfAllLabels(ctx context.Context, hostID uint, labelNames []string) (bool, error)
// TODO JUAN: Refactor this to use the Operating System type instead.
// HostIDsByOSVersion retrieves the IDs of all host matching osVersion
HostIDsByOSVersion(ctx context.Context, osVersion OSVersion, offset int, limit int) ([]uint, error)

View File

@ -57,6 +57,8 @@ type ExtensionInfo struct {
Platform string `json:"platform"`
// Channel is the select TUF channel to listen for updates.
Channel string `json:"channel"`
// Labels are the label names the host must be member of to run this extension.
Labels []string `json:"labels,omitempty"`
}
// Extensions holds a set of extensions to apply to an Orbit client.

View File

@ -174,6 +174,8 @@ type HostIDsByNameFunc func(ctx context.Context, filter fleet.TeamFilter, hostna
type HostIDsByOSIDFunc func(ctx context.Context, osID uint, offset int, limit int) ([]uint, error)
type HostMemberOfAllLabelsFunc func(ctx context.Context, hostID uint, labelNames []string) (bool, error)
type HostIDsByOSVersionFunc func(ctx context.Context, osVersion fleet.OSVersion, offset int, limit int) ([]uint, error)
type HostByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.Host, error)
@ -913,6 +915,9 @@ type DataStore struct {
HostIDsByOSIDFunc HostIDsByOSIDFunc
HostIDsByOSIDFuncInvoked bool
HostMemberOfAllLabelsFunc HostMemberOfAllLabelsFunc
HostMemberOfAllLabelsFuncInvoked bool
HostIDsByOSVersionFunc HostIDsByOSVersionFunc
HostIDsByOSVersionFuncInvoked bool
@ -2218,6 +2223,13 @@ func (s *DataStore) HostIDsByOSID(ctx context.Context, osID uint, offset int, li
return s.HostIDsByOSIDFunc(ctx, osID, offset, limit)
}
func (s *DataStore) HostMemberOfAllLabels(ctx context.Context, hostID uint, labelNames []string) (bool, error) {
s.mu.Lock()
s.HostMemberOfAllLabelsFuncInvoked = true
s.mu.Unlock()
return s.HostMemberOfAllLabelsFunc(ctx, hostID, labelNames)
}
func (s *DataStore) HostIDsByOSVersion(ctx context.Context, osVersion fleet.OSVersion, offset int, limit int) ([]uint, error) {
s.mu.Lock()
s.HostIDsByOSVersionFuncInvoked = true

View File

@ -329,7 +329,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle
if newAppConfig.AgentOptions != nil {
// if there were Agent Options in the new app config, then it replaced the
// agent options in the resulting app config, so validate those.
if err := fleet.ValidateJSONAgentOptions(*appConfig.AgentOptions); err != nil {
if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium()); err != nil {
err = fleet.NewUserMessageError(err, http.StatusBadRequest)
if applyOpts.Force && !applyOpts.DryRun {
level.Info(svc.logger).Log("err", err, "msg", "force-apply appConfig agent options with validation errors")

View File

@ -7610,6 +7610,122 @@ func (s *integrationTestSuite) TestDirectIngestSoftwareWithInvalidFields() {
require.NotZero(t, wiresharkSoftware.ID)
}
func (s *integrationTestSuite) TestOrbitConfigExtensions() {
t := s.T()
ctx := context.Background()
appCfg, err := s.ds.AppConfig(ctx)
require.NoError(t, err)
defer func() {
err = s.ds.SaveAppConfig(ctx, appCfg)
require.NoError(t, err)
}()
// Orbit client gets no extensions if extensions are not configured.
orbitLinuxClient := createOrbitEnrolledHost(t, "linux", "foobar1", s.ds)
resp := orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp)
require.Empty(t, resp.Extensions)
// Attempt to add extensions (should succeed).
s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{
"agent_options": {
"config": {
"options": {
"pack_delimiter": "/",
"logger_tls_period": 10,
"distributed_plugin": "tls",
"disable_distributed": false,
"logger_tls_endpoint": "/api/osquery/log",
"distributed_interval": 10,
"distributed_tls_max_attempts": 3
}
},
"extensions": {
"hello_world_linux": {
"channel": "stable",
"platform": "linux"
},
"hello_mars_linux": {
"channel": "stable",
"platform": "linux"
},
"hello_world_macos": {
"channel": "stable",
"platform": "macos"
}
}
}
}`), http.StatusOK)
// Attempt to add labels to extensions (only available on premium).
s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{
"agent_options": {
"config": {
"options": {
"pack_delimiter": "/",
"logger_tls_period": 10,
"distributed_plugin": "tls",
"disable_distributed": false,
"logger_tls_endpoint": "/api/osquery/log",
"distributed_interval": 10,
"distributed_tls_max_attempts": 3
}
},
"extensions": {
"hello_world_linux": {
"channel": "stable",
"platform": "linux"
},
"hello_world_macos": {
"labels": [
"All hosts",
"Some label"
],
"channel": "stable",
"platform": "macos"
},
"hello_world_windows": {
"channel": "stable",
"platform": "windows"
}
}
}
}`), http.StatusBadRequest)
// Orbit client gets extensions configured for its platform.
orbitDarwinClient := createOrbitEnrolledHost(t, "darwin", "foobar2", s.ds)
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitDarwinClient.OrbitNodeKey)), http.StatusOK, &resp)
require.JSONEq(t, `{
"hello_world_macos": {
"platform": "macos",
"channel": "stable"
}
}`, string(resp.Extensions))
orbitWindowsClient := createOrbitEnrolledHost(t, "windows", "foobar3", s.ds)
// Orbit client gets no extensions if none of the platforms target it.
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitWindowsClient.OrbitNodeKey)), http.StatusOK, &resp)
require.Empty(t, resp.Extensions)
// Orbit client gets the two extensions configured for its platform.
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp)
require.JSONEq(t, `{
"hello_world_linux": {
"channel": "stable",
"platform": "linux"
},
"hello_mars_linux": {
"channel": "stable",
"platform": "linux"
}
}`, string(resp.Extensions))
}
func (s *integrationTestSuite) TestHostsReportWithPolicyResults() {
t := s.T()
ctx := context.Background()

View File

@ -3924,3 +3924,185 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() {
errMsg = extractServerErrorText(res.Body)
require.Contains(t, errMsg, fleet.RunScriptHostOfflineErrMsg)
}
func (s *integrationEnterpriseTestSuite) TestOrbitConfigExtensions() {
t := s.T()
ctx := context.Background()
appCfg, err := s.ds.AppConfig(ctx)
require.NoError(t, err)
defer func() {
err = s.ds.SaveAppConfig(ctx, appCfg)
require.NoError(t, err)
}()
foobarLabel, err := s.ds.NewLabel(ctx, &fleet.Label{
Name: "Foobar",
Query: "SELECT 1;",
})
require.NoError(t, err)
zoobarLabel, err := s.ds.NewLabel(ctx, &fleet.Label{
Name: "Zoobar",
Query: "SELECT 1;",
})
require.NoError(t, err)
allHostsLabel, err := s.ds.GetLabelSpec(ctx, "All hosts")
require.NoError(t, err)
orbitDarwinClient := createOrbitEnrolledHost(t, "darwin", "foobar1", s.ds)
orbitLinuxClient := createOrbitEnrolledHost(t, "linux", "foobar2", s.ds)
orbitWindowsClient := createOrbitEnrolledHost(t, "windows", "foobar3", s.ds)
// orbitDarwinClient is member of 'All hosts' and 'Zoobar' labels.
err = s.ds.RecordLabelQueryExecutions(ctx, orbitDarwinClient, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
zoobarLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
// orbitLinuxClient is member of 'All hosts' and 'Foobar' labels.
err = s.ds.RecordLabelQueryExecutions(ctx, orbitLinuxClient, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
foobarLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
// orbitWindowsClient is member of the 'All hosts' label only.
err = s.ds.RecordLabelQueryExecutions(ctx, orbitWindowsClient, map[uint]*bool{
allHostsLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
// Attempt to add labels to extensions.
s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{
"agent_options": {
"config": {
"options": {
"pack_delimiter": "/",
"logger_tls_period": 10,
"distributed_plugin": "tls",
"disable_distributed": false,
"logger_tls_endpoint": "/api/osquery/log",
"distributed_interval": 10,
"distributed_tls_max_attempts": 3
}
},
"extensions": {
"hello_world_linux": {
"labels": [
"All hosts",
"Foobar"
],
"channel": "stable",
"platform": "linux"
},
"hello_world_macos": {
"labels": [
"All hosts",
"Foobar"
],
"channel": "stable",
"platform": "macos"
},
"hello_mars_macos": {
"labels": [
"All hosts",
"Zoobar"
],
"channel": "stable",
"platform": "macos"
},
"hello_world_windows": {
"labels": [
"Zoobar"
],
"channel": "stable",
"platform": "windows"
},
"hello_mars_windows": {
"labels": [
"Foobar"
],
"channel": "stable",
"platform": "windows"
}
}
}
}`), http.StatusOK)
resp := orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitDarwinClient.OrbitNodeKey)), http.StatusOK, &resp)
require.JSONEq(t, `{
"hello_mars_macos": {
"channel": "stable",
"platform": "macos"
}
}`, string(resp.Extensions))
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp)
require.JSONEq(t, `{
"hello_world_linux": {
"channel": "stable",
"platform": "linux"
}
}`, string(resp.Extensions))
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitWindowsClient.OrbitNodeKey)), http.StatusOK, &resp)
require.Empty(t, string(resp.Extensions))
// orbitDarwinClient is now also a member of the 'Foobar' label.
err = s.ds.RecordLabelQueryExecutions(ctx, orbitDarwinClient, map[uint]*bool{
foobarLabel.ID: ptr.Bool(true),
}, time.Now(), false)
require.NoError(t, err)
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitDarwinClient.OrbitNodeKey)), http.StatusOK, &resp)
require.JSONEq(t, `{
"hello_world_macos": {
"channel": "stable",
"platform": "macos"
},
"hello_mars_macos": {
"channel": "stable",
"platform": "macos"
}
}`, string(resp.Extensions))
// orbitLinuxClient is no longer a member of the 'Foobar' label.
err = s.ds.RecordLabelQueryExecutions(ctx, orbitLinuxClient, map[uint]*bool{
foobarLabel.ID: nil,
}, time.Now(), false)
require.NoError(t, err)
resp = orbitGetConfigResponse{}
s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp)
require.Empty(t, string(resp.Extensions))
// Attempt to set non-existent labels in the config.
s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{
"agent_options": {
"config": {
"options": {
"pack_delimiter": "/",
"logger_tls_period": 10,
"distributed_plugin": "tls",
"disable_distributed": false,
"logger_tls_endpoint": "/api/osquery/log",
"distributed_interval": 10,
"distributed_tls_max_attempts": 3
}
},
"extensions": {
"hello_world_linux": {
"labels": [
"All hosts",
"Doesn't exist"
],
"channel": "stable",
"platform": "linux"
}
}
}
}`), http.StatusBadRequest)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
hostctx "github.com/fleetdm/fleet/v4/server/contexts/host"
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/log/level"
@ -168,8 +169,6 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro
// this is not a user-authenticated endpoint
svc.authz.SkipAuthorization(ctx)
var notifs fleet.OrbitConfigNotifications
host, ok := hostctx.FromContext(ctx)
if !ok {
return fleet.OrbitConfig{}, fleet.OrbitError{Message: "internal error: missing host from request context"}
@ -181,6 +180,7 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro
}
// set the host's orbit notifications for macOS MDM
var notifs fleet.OrbitConfigNotifications
if appConfig.MDM.EnabledAndConfigured && host.IsOsqueryEnrolled() {
// TODO(mna): all those notifications implied a macos hosts, but none of
// the checks enforce that (only indirectly in some cases, like
@ -250,7 +250,7 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro
}
}
extensionsFiltered, err := filterExtensionsByPlatform(opts.Extensions, host.Platform)
extensionsFiltered, err := svc.filterExtensionsForHost(ctx, opts.Extensions, host)
if err != nil {
return fleet.OrbitConfig{}, err
}
@ -286,7 +286,7 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro
}
}
extensionsFiltered, err := filterExtensionsByPlatform(opts.Extensions, host.Platform)
extensionsFiltered, err := svc.filterExtensionsForHost(ctx, opts.Extensions, host)
if err != nil {
return fleet.OrbitConfig{}, err
}
@ -308,20 +308,45 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro
}, nil
}
// filterExtensionsByPlatform filters a extensions configuration depending on the host platform.
// (to not send extensions targeted to other operating systems).
func filterExtensionsByPlatform(extensions json.RawMessage, hostPlatform string) (json.RawMessage, error) {
// filterExtensionsForHost filters a extensions configuration depending on the host platform and label membership.
//
// If all extensions are filtered, then it returns (nil, nil) (Orbit expects empty extensions if there
// are no extensions for the host.)
func (svc *Service) filterExtensionsForHost(ctx context.Context, extensions json.RawMessage, host *fleet.Host) (json.RawMessage, error) {
if len(extensions) == 0 {
return extensions, nil
return nil, nil
}
var extensionsInfo fleet.Extensions
if err := json.Unmarshal(extensions, &extensionsInfo); err != nil {
return nil, err
return nil, ctxerr.Wrap(ctx, err, "unmarshal extensions config")
}
// Filter the extensions by platform.
extensionsInfo.FilterByHostPlatform(host.Platform)
// Filter the extensions by labels (premium only feature).
if license, _ := license.FromContext(ctx); license != nil && license.IsPremium() {
for extensionName, extensionInfo := range extensionsInfo {
hostIsMemberOfAllLabels, err := svc.ds.HostMemberOfAllLabels(ctx, host.ID, extensionInfo.Labels)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "check host labels")
}
if hostIsMemberOfAllLabels {
// Do not filter out, but there's no need to send the label names to the devices.
extensionInfo.Labels = nil
extensionsInfo[extensionName] = extensionInfo
} else {
delete(extensionsInfo, extensionName)
}
}
}
// Orbit expects empty message if no extensions apply.
if len(extensionsInfo) == 0 {
return nil, nil
}
extensionsInfo.FilterByHostPlatform(hostPlatform)
extensionsFiltered, err := json.Marshal(extensionsInfo)
if err != nil {
return nil, err
return nil, ctxerr.Wrap(ctx, err, "marshal extensions config")
}
return extensionsFiltered, nil
}