only include policies in device endpoints for premium users (#6077)

This removes policy information from `GET /api/_version_/fleet/device/{token}` from non-premium Fleet instances.

Starting the server with `./build/fleet serve --dev --dev_license`

```bash
$ curl -s https://localhost:8080/api/latest/fleet/device/1804e808-171f-4dda-9bec-f695b2f2371a | jq '.host.policies // "not present"'
[
  {
    "id": 3,
    "name": "Antivirus healthy (Linux)",
    "query": "SELECT score FROM (SELECT case when COUNT(*) = 2 then 1 ELSE 0 END AS score FROM processes WHERE (name = 'clamd') OR (name = 'freshclam')) WHERE score == 1;",
    "description": "Checks that both ClamAV's daemon and its updater service (freshclam) are running.",
    "author_id": 1,
    "author_name": "Roberto",
    "author_email": "test@example.com",
    "team_id": null,
    "resolution": "Ensure ClamAV and Freshclam are installed and running.",
    "platform": "darwin,linux",
    "created_at": "2022-05-23T20:53:36Z",
    "updated_at": "2022-06-03T13:17:42Z",
    "response": ""
  }
]
```

Starting the server with `./build/fleet serve --dev`

```bash
$ curl -s https://localhost:8080/api/latest/fleet/device/1804e808-171f-4dda-9bec-f695b2f2371a | jq '.host.policies // "not present"'
"not present"
```
This commit is contained in:
Roberto Dip 2022-06-07 13:27:13 -03:00 committed by GitHub
parent d3d07a12d9
commit 19f995d3b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 106 additions and 38 deletions

View File

@ -0,0 +1 @@
* Modified `GET /api/_version_/fleet/device/{token}` to only include policies for premium users.

View File

@ -6,12 +6,16 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
)
func (svc *Service) GetHost(ctx context.Context, id uint, includeCVEScores bool) (*fleet.HostDetail, error) {
// reuse GetHost, but include cve scores in premium version
return svc.Service.GetHost(ctx, id, true)
func (svc *Service) GetHost(ctx context.Context, id uint, opts fleet.HostDetailOptions) (*fleet.HostDetail, error) {
// reuse GetHost, but include premium details
opts.IncludeCVEScores = true
opts.IncludePolicies = true
return svc.Service.GetHost(ctx, id, opts)
}
func (svc *Service) HostByIdentifier(ctx context.Context, identifier string, includeCVEScores bool) (*fleet.HostDetail, error) {
// reuse HostByIdentifier, but include cve scores in premium version
return svc.Service.HostByIdentifier(ctx, identifier, true)
func (svc *Service) HostByIdentifier(ctx context.Context, identifier string, opts fleet.HostDetailOptions) (*fleet.HostDetail, error) {
// reuse HostByIdentifier, but include premium options
opts.IncludeCVEScores = true
opts.IncludePolicies = true
return svc.Service.HostByIdentifier(ctx, identifier, opts)
}

View File

@ -155,7 +155,7 @@ type HostDetail struct {
// Packs is the list of packs the host is a member of.
Packs []*Pack `json:"packs"`
// Policies is the list of policies and whether it passes for the host
Policies []*HostPolicy `json:"policies"`
Policies *[]*HostPolicy `json:"policies,omitempty"`
}
const (
@ -324,3 +324,8 @@ type OSVersion struct {
Name string `json:"name"`
Platform string `json:"platform"`
}
type HostDetailOptions struct {
IncludeCVEScores bool
IncludePolicies bool
}

View File

@ -239,12 +239,20 @@ type Service interface {
AuthenticateDevice(ctx context.Context, authToken string) (host *Host, debug bool, err error)
ListHosts(ctx context.Context, opt HostListOptions) (hosts []*Host, err error)
GetHost(ctx context.Context, id uint, includeCVEScores bool) (host *HostDetail, err error)
// GetHost returns the host with the provided ID.
//
// The return value can also include policy information and CVE scores based
// on the values provided to `opts`
GetHost(ctx context.Context, id uint, opts HostDetailOptions) (host *HostDetail, err error)
GetHostSummary(ctx context.Context, teamID *uint, platform *string) (summary *HostSummary, err error)
DeleteHost(ctx context.Context, id uint) (err error)
// HostByIdentifier returns one host matching the provided identifier. Possible matches can be on
// osquery_host_identifier, node_key, UUID, or hostname.
HostByIdentifier(ctx context.Context, identifier string, includeCVEScores bool) (*HostDetail, error)
// HostByIdentifier returns one host matching the provided identifier.
// Possible matches can be on osquery_host_identifier, node_key, UUID, or
// hostname.
//
// The return value can also include policy information and CVE scores based
// on the values provided to `opts`
HostByIdentifier(ctx context.Context, identifier string, opts HostDetailOptions) (*HostDetail, error)
// RefetchHost requests a refetch of host details for the provided host.
RefetchHost(ctx context.Context, id uint) (err error)
// AddHostsToTeam adds hosts to an existing team, clearing their team settings if teamID is nil.

View File

@ -37,7 +37,11 @@ func getDeviceHostEndpoint(ctx context.Context, request interface{}, svc fleet.S
}
// must still load the full host details, as it returns more information
hostDetails, err := svc.GetHost(ctx, host.ID, false)
opts := fleet.HostDetailOptions{
IncludeCVEScores: false,
IncludePolicies: false,
}
hostDetails, err := svc.GetHost(ctx, host.ID, opts)
if err != nil {
return getDeviceHostResponse{Err: err}, nil
}

View File

@ -255,7 +255,11 @@ func (r getHostResponse) error() error { return r.Err }
func getHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getHostRequest)
host, err := svc.GetHost(ctx, req.ID, false)
opts := fleet.HostDetailOptions{
IncludeCVEScores: false,
IncludePolicies: true, // intentionally true to preserve existing behavior
}
host, err := svc.GetHost(ctx, req.ID, opts)
if err != nil {
return getHostResponse{Err: err}, nil
}
@ -268,7 +272,7 @@ func getHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service
return getHostResponse{Host: resp}, nil
}
func (svc *Service) GetHost(ctx context.Context, id uint, includeCVEScores bool) (*fleet.HostDetail, error) {
func (svc *Service) GetHost(ctx context.Context, id uint, opts fleet.HostDetailOptions) (*fleet.HostDetail, error) {
alreadyAuthd := svc.authz.IsAuthenticatedWith(ctx, authz.AuthnDeviceToken)
if !alreadyAuthd {
// First ensure the user has access to list hosts, then check the specific
@ -290,7 +294,7 @@ func (svc *Service) GetHost(ctx context.Context, id uint, includeCVEScores bool)
}
}
hostDetails, err := svc.getHostDetails(ctx, host, includeCVEScores)
hostDetails, err := svc.getHostDetails(ctx, host, opts)
if err != nil {
return nil, err
}
@ -393,7 +397,11 @@ type hostByIdentifierRequest struct {
func hostByIdentifierEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*hostByIdentifierRequest)
host, err := svc.HostByIdentifier(ctx, req.Identifier, false)
opts := fleet.HostDetailOptions{
IncludeCVEScores: false,
IncludePolicies: true, // intentionally true to preserve existing behavior
}
host, err := svc.HostByIdentifier(ctx, req.Identifier, opts)
if err != nil {
return getHostResponse{Err: err}, nil
}
@ -408,7 +416,7 @@ func hostByIdentifierEndpoint(ctx context.Context, request interface{}, svc flee
}, nil
}
func (svc *Service) HostByIdentifier(ctx context.Context, identifier string, includeCVEScores bool) (*fleet.HostDetail, error) {
func (svc *Service) HostByIdentifier(ctx context.Context, identifier string, opts fleet.HostDetailOptions) (*fleet.HostDetail, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
@ -423,7 +431,7 @@ func (svc *Service) HostByIdentifier(ctx context.Context, identifier string, inc
return nil, err
}
hostDetails, err := svc.getHostDetails(ctx, host, includeCVEScores)
hostDetails, err := svc.getHostDetails(ctx, host, opts)
if err != nil {
return nil, err
}
@ -612,8 +620,8 @@ func (svc *Service) RefetchHost(ctx context.Context, id uint) error {
return nil
}
func (svc *Service) getHostDetails(ctx context.Context, host *fleet.Host, includeCVEScores bool) (*fleet.HostDetail, error) {
if err := svc.ds.LoadHostSoftware(ctx, host, includeCVEScores); err != nil {
func (svc *Service) getHostDetails(ctx context.Context, host *fleet.Host, opts fleet.HostDetailOptions) (*fleet.HostDetail, error) {
if err := svc.ds.LoadHostSoftware(ctx, host, opts.IncludeCVEScores); err != nil {
return nil, ctxerr.Wrap(ctx, err, "load host software")
}
@ -627,9 +635,18 @@ func (svc *Service) getHostDetails(ctx context.Context, host *fleet.Host, includ
return nil, ctxerr.Wrap(ctx, err, "get packs for host")
}
policies, err := svc.ds.ListPoliciesForHost(ctx, host)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get policies for host")
var policies *[]*fleet.HostPolicy
if opts.IncludePolicies {
hp, err := svc.ds.ListPoliciesForHost(ctx, host)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get policies for host")
}
if hp == nil {
hp = []*fleet.HostPolicy{}
}
policies = &hp
}
return &fleet.HostDetail{Host: *host, Labels: labels, Packs: packs, Policies: policies}, nil

View File

@ -50,7 +50,11 @@ func TestHostDetails(t *testing.T) {
return nil, nil
}
hostDetail, err := svc.getHostDetails(test.UserContext(test.UserAdmin), host, false)
opts := fleet.HostDetailOptions{
IncludeCVEScores: false,
IncludePolicies: false,
}
hostDetail, err := svc.getHostDetails(test.UserContext(test.UserAdmin), host, opts)
require.NoError(t, err)
assert.Equal(t, expectedLabels, hostDetail.Labels)
assert.Equal(t, expectedPacks, hostDetail.Packs)
@ -182,17 +186,21 @@ func TestHostAuth(t *testing.T) {
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
opts := fleet.HostDetailOptions{
IncludeCVEScores: false,
IncludePolicies: false,
}
_, err := svc.GetHost(ctx, 1, false)
_, err := svc.GetHost(ctx, 1, opts)
checkAuthErr(t, tt.shouldFailTeamRead, err)
_, err = svc.HostByIdentifier(ctx, "1", false)
_, err = svc.HostByIdentifier(ctx, "1", opts)
checkAuthErr(t, tt.shouldFailTeamRead, err)
_, err = svc.GetHost(ctx, 2, false)
_, err = svc.GetHost(ctx, 2, opts)
checkAuthErr(t, tt.shouldFailGlobalRead, err)
_, err = svc.HostByIdentifier(ctx, "2", false)
_, err = svc.HostByIdentifier(ctx, "2", opts)
checkAuthErr(t, tt.shouldFailGlobalRead, err)
err = svc.DeleteHost(ctx, 1)
@ -478,12 +486,12 @@ func TestEmptyTeamOSVersions(t *testing.T) {
assert.Empty(t, vers.OSVersions)
// team does not exist
vers, err = svc.OSVersions(test.UserContext(test.UserAdmin), ptr.Uint(3), ptr.String("darwin"))
_, err = svc.OSVersions(test.UserContext(test.UserAdmin), ptr.Uint(3), ptr.String("darwin"))
require.Error(t, err)
require.Equal(t, "not found", fmt.Sprint(err))
// some unknown error
vers, err = svc.OSVersions(test.UserContext(test.UserAdmin), ptr.Uint(4), ptr.String("darwin"))
_, err = svc.OSVersions(test.UserContext(test.UserAdmin), ptr.Uint(4), ptr.String("darwin"))
require.Error(t, err)
require.Equal(t, "some unknown error", fmt.Sprint(err))
}

View File

@ -1728,12 +1728,13 @@ func (s *integrationTestSuite) TestHostDetailsPolicies() {
require.NoError(t, err)
require.Nil(t, r.Err)
hd := r.Host.HostDetail
require.Len(t, hd.Policies, 2)
require.True(t, reflect.DeepEqual(gpResp.Policy.PolicyData, hd.Policies[0].PolicyData))
require.Equal(t, hd.Policies[0].Response, "pass")
policies := *hd.Policies
require.Len(t, policies, 2)
require.True(t, reflect.DeepEqual(gpResp.Policy.PolicyData, policies[0].PolicyData))
require.Equal(t, policies[0].Response, "pass")
require.True(t, reflect.DeepEqual(tpResp.Policy.PolicyData, hd.Policies[1].PolicyData))
require.Equal(t, hd.Policies[1].Response, "") // policy didn't "run"
require.True(t, reflect.DeepEqual(tpResp.Policy.PolicyData, policies[1].PolicyData))
require.Equal(t, policies[1].Response, "") // policy didn't "run"
// Try to create a global policy with an existing name.
s.DoJSON("POST", "/api/latest/fleet/policies", gpParams, http.StatusConflict, &gpResp)
@ -4522,11 +4523,13 @@ func (s *integrationTestSuite) TestDeviceAuthenticatedEndpoints() {
require.Equal(t, hosts[0].ID, getHostResp.Host.ID)
require.False(t, getHostResp.Host.RefetchRequested)
require.Equal(t, "http://example.com/logo", getHostResp.OrgLogoURL)
require.Nil(t, getHostResp.Host.Policies)
hostDevResp := getHostResp.Host
// make request for same host on the host details API endpoint, responses should match
// make request for same host on the host details API endpoint, responses should match, except for policies
getHostResp = getDeviceHostResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/hosts/%d", hosts[0].ID), nil, http.StatusOK, &getHostResp)
getHostResp.Host.Policies = nil
require.Equal(t, hostDevResp, getHostResp.Host)
// request a refetch for that valid host

View File

@ -953,6 +953,12 @@ func (s *integrationEnterpriseTestSuite) TestExternalIntegrationsTeamConfig() {
func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
t := s.T()
ac, err := s.ds.AppConfig(context.Background())
require.NoError(t, err)
ac.OrgInfo.OrgLogoURL = "http://example.com/logo"
err = s.ds.SaveAppConfig(context.Background(), ac)
require.NoError(t, err)
team, err := s.ds.NewTeam(context.Background(), &fleet.Team{
ID: 51,
Name: "team1-policies",
@ -975,7 +981,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
err = s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{host.ID})
require.NoError(t, err)
// create an auth token for hosts[0]
// create an auth token for host
token := "much_valid"
mysql.ExecAdhocSQL(t, s.ds, func(db sqlx.ExtContext) error {
_, err := db.ExecContext(context.Background(), `INSERT INTO host_device_auth (host_id, token) VALUES (?, ?)`, host.ID, token)
@ -999,7 +1005,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
s.DoJSON("POST", "/api/latest/fleet/policies", gpParams, http.StatusOK, &gpResp)
require.NotNil(t, gpResp.Policy)
// add a policy to team 1
// add a policy to team
oldToken := s.token
t.Cleanup(func() {
s.token = oldToken
@ -1039,10 +1045,22 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
res := s.DoRawNoAuth("GET", "/api/latest/fleet/device/invalid_token/policies", nil, http.StatusUnauthorized)
res.Body.Close()
// GET `/api/_version_/fleet/device/{token}/policies`
listDevicePoliciesResp := listDevicePoliciesResponse{}
res = s.DoRawNoAuth("GET", "/api/latest/fleet/device/"+token+"/policies", nil, http.StatusOK)
json.NewDecoder(res.Body).Decode(&listDevicePoliciesResp)
res.Body.Close()
require.Len(t, listDevicePoliciesResp.Policies, 2)
require.NoError(t, listDevicePoliciesResp.Err)
// GET `/api/_version_/fleet/device/{token}`
getDeviceHostResp := getDeviceHostResponse{}
res = s.DoRawNoAuth("GET", "/api/latest/fleet/device/"+token, nil, http.StatusOK)
json.NewDecoder(res.Body).Decode(&getDeviceHostResp)
res.Body.Close()
require.NoError(t, getDeviceHostResp.Err)
require.Equal(t, host.ID, getDeviceHostResp.Host.ID)
require.False(t, getDeviceHostResp.Host.RefetchRequested)
require.Equal(t, "http://example.com/logo", getDeviceHostResp.OrgLogoURL)
require.Len(t, *getDeviceHostResp.Host.Policies, 2)
}