Fleet desktop should use lightweight endpoint for getting failing policies count (#8159)

Fleet desktop should use lightweight endpoint for getting failing policies count
This commit is contained in:
Juan Fernandez 2022-10-12 16:13:43 -04:00 committed by GitHub
parent 91ea7bf244
commit 6939af045d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 42 deletions

View File

@ -0,0 +1,2 @@
- Updated Fleet Desktop to use the new endpoint introduced in
https://github.com/fleetdm/fleet/issues/7084

View File

@ -120,7 +120,7 @@ func main() {
for { for {
refetchToken() refetchToken()
_, err := client.ListDevicePolicies(tokenReader.GetCached()) _, err := client.NumberOfFailingPolicies(tokenReader.GetCached())
if err == nil || errors.Is(err, service.ErrMissingLicense) { if err == nil || errors.Is(err, service.ErrMissingLicense) {
log.Debug().Msg("enabling tray items") log.Debug().Msg("enabling tray items")
@ -173,7 +173,7 @@ func main() {
defer tic.Stop() defer tic.Stop()
for { for {
policies, err := client.ListDevicePolicies(tokenReader.GetCached()) failingPolicies, err := client.NumberOfFailingPolicies(tokenReader.GetCached())
switch { switch {
case err == nil: case err == nil:
// OK // OK
@ -185,28 +185,21 @@ func main() {
<-checkToken() <-checkToken()
continue continue
default: default:
log.Error().Err(err).Msg("get device URL") log.Error().Err(err).Msg("get failing policies")
continue continue
} }
failedPolicyCount := 0 if failingPolicies > 0 {
for _, policy := range policies {
if policy.Response != "pass" {
failedPolicyCount++
}
}
if failedPolicyCount > 0 {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
// Windows (or maybe just the systray library?) doesn't support color emoji // Windows (or maybe just the systray library?) doesn't support color emoji
// in the system tray menu, so we use text as an alternative. // in the system tray menu, so we use text as an alternative.
if failedPolicyCount == 1 { if failingPolicies == 1 {
myDeviceItem.SetTitle("My device (1 issue)") myDeviceItem.SetTitle("My device (1 issue)")
} else { } else {
myDeviceItem.SetTitle(fmt.Sprintf("My device (%d issues)", failedPolicyCount)) myDeviceItem.SetTitle(fmt.Sprintf("My device (%d issues)", failingPolicies))
} }
} else { } else {
myDeviceItem.SetTitle(fmt.Sprintf("🔴 My device (%d)", failedPolicyCount)) myDeviceItem.SetTitle(fmt.Sprintf("🔴 My device (%d)", failingPolicies))
} }
} else { } else {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {

View File

@ -49,17 +49,6 @@ func (dc *DeviceClient) request(verb string, path string, query string, response
return dc.parseResponse(verb, path, response, responseDest) return dc.parseResponse(verb, path, response, responseDest)
} }
// ListDevicePolicies fetches all policies for the device with the provided token
func (dc *DeviceClient) ListDevicePolicies(token string) ([]*fleet.HostPolicy, error) {
verb, path := "GET", "/api/latest/fleet/device/"+token+"/policies"
var responseBody listDevicePoliciesResponse
err := dc.request(verb, path, "", &responseBody)
if err != nil {
return nil, err
}
return responseBody.Policies, nil
}
// TransparencyURL returns an URL that the server will use to redirect to the // TransparencyURL returns an URL that the server will use to redirect to the
// transparency URL configured by the user // transparency URL configured by the user
func (dc *DeviceClient) TransparencyURL(token string) string { func (dc *DeviceClient) TransparencyURL(token string) string {
@ -74,8 +63,8 @@ func (dc *DeviceClient) DeviceURL(token string) string {
// CheckToken checks if a token is valid by making an authenticated request to // CheckToken checks if a token is valid by making an authenticated request to
// the server // the server
func (dc *DeviceClient) CheckToken(token string) error { func (dc *DeviceClient) CheckToken(token string) error {
verb, path := "GET", "/api/latest/fleet/device/"+token+"/policies" _, err := dc.NumberOfFailingPolicies(token)
return dc.request(verb, path, "", &FleetDesktopResponse{}) return err
} }
// Ping sends a ping to the server using the device/ping endpoint // Ping sends a ping to the server using the device/ping endpoint
@ -91,3 +80,41 @@ func (dc *DeviceClient) Ping() error {
return err return err
} }
func (dc *DeviceClient) getListDevicePolicies(token string) ([]*fleet.HostPolicy, error) {
verb, path := "GET", "/api/latest/fleet/device/"+token+"/policies"
var responseBody listDevicePoliciesResponse
err := dc.request(verb, path, "", &responseBody)
return responseBody.Policies, err
}
func (dc *DeviceClient) getMinDesktopPayload(token string) (fleetDesktopResponse, error) {
verb, path := "GET", "/api/latest/fleet/device/"+token+"/desktop"
var r fleetDesktopResponse
err := dc.request(verb, path, "", &r)
return r, err
}
func (dc *DeviceClient) NumberOfFailingPolicies(token string) (uint, error) {
r, err := dc.getMinDesktopPayload(token)
if err == nil {
return uintValueOrZero(r.FailingPolicies), nil
}
if errors.Is(err, notFoundErr{}) {
policies, err := dc.getListDevicePolicies(token)
if err != nil {
return 0, err
}
var failingPolicies uint
for _, policy := range policies {
if policy.Response != "pass" {
failingPolicies++
}
}
return failingPolicies, nil
}
return 0, err
}

View File

@ -38,24 +38,23 @@ func TestDeviceClientGetDesktopPayload(t *testing.T) {
t.Run("with wrong license", func(t *testing.T) { t.Run("with wrong license", func(t *testing.T) {
mockRequestDoer.statusCode = http.StatusPaymentRequired mockRequestDoer.statusCode = http.StatusPaymentRequired
_, err = client.ListDevicePolicies(token) _, err = client.NumberOfFailingPolicies(token)
require.ErrorIs(t, err, ErrMissingLicense) require.ErrorIs(t, err, ErrMissingLicense)
}) })
t.Run("with empty policies", func(t *testing.T) { t.Run("with no failing policies", func(t *testing.T) {
mockRequestDoer.statusCode = http.StatusOK mockRequestDoer.statusCode = http.StatusOK
mockRequestDoer.resBody = `{"policies": []}` mockRequestDoer.resBody = `{}`
policies, err := client.ListDevicePolicies(token) result, err := client.NumberOfFailingPolicies(token)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, policies, 0) require.Equal(t, uint(0), result)
}) })
t.Run("with policies", func(t *testing.T) { t.Run("with failing policies", func(t *testing.T) {
mockRequestDoer.statusCode = http.StatusOK mockRequestDoer.statusCode = http.StatusOK
mockRequestDoer.resBody = `{"policies": [{"id": 1}]}` mockRequestDoer.resBody = `{"failing_policies_count": 1}`
policies, err := client.ListDevicePolicies(token) result, err := client.NumberOfFailingPolicies(token)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, policies, 1) require.Equal(t, uint(1), result)
require.Equal(t, uint(1), policies[0].ID)
}) })
} }

View File

@ -41,7 +41,7 @@ func (svc *Service) DisableAuthForPing(ctx context.Context) {
// Fleet Desktop endpoints // Fleet Desktop endpoints
///////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////
type FleetDesktopResponse struct { type fleetDesktopResponse struct {
Err error `json:"error,omitempty"` Err error `json:"error,omitempty"`
FailingPolicies *uint `json:"failing_policies_count,omitempty"` FailingPolicies *uint `json:"failing_policies_count,omitempty"`
} }
@ -61,15 +61,15 @@ func getFleetDesktopEndpoint(ctx context.Context, request interface{}, svc fleet
if !ok { if !ok {
err := ctxerr.Wrap(ctx, fleet.NewAuthRequiredError("internal error: missing host from request context")) err := ctxerr.Wrap(ctx, fleet.NewAuthRequiredError("internal error: missing host from request context"))
return FleetDesktopResponse{Err: err}, nil return fleetDesktopResponse{Err: err}, nil
} }
r, err := svc.FailingPoliciesCount(ctx, host) r, err := svc.FailingPoliciesCount(ctx, host)
if err != nil { if err != nil {
return FleetDesktopResponse{Err: err}, nil return fleetDesktopResponse{Err: err}, nil
} }
return FleetDesktopResponse{FailingPolicies: &r}, nil return fleetDesktopResponse{FailingPolicies: &r}, nil
} }
///////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////

View File

@ -1278,7 +1278,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() {
require.Len(t, *getDeviceHostResp.Host.Policies, 2) require.Len(t, *getDeviceHostResp.Host.Policies, 2)
// GET `/api/_version_/fleet/device/{token}/desktop` // GET `/api/_version_/fleet/device/{token}/desktop`
getDesktopResp := FleetDesktopResponse{} getDesktopResp := fleetDesktopResponse{}
res = s.DoRawNoAuth("GET", "/api/latest/fleet/device/"+token+"/desktop", nil, http.StatusOK) res = s.DoRawNoAuth("GET", "/api/latest/fleet/device/"+token+"/desktop", nil, http.StatusOK)
json.NewDecoder(res.Body).Decode(&getDesktopResp) json.NewDecoder(res.Body).Decode(&getDesktopResp)
res.Body.Close() res.Body.Close()