From 6939af045d7b9ffaf22d9a5b3c26b5f0d722c225 Mon Sep 17 00:00:00 2001 From: Juan Fernandez Date: Wed, 12 Oct 2022 16:13:43 -0400 Subject: [PATCH] Fleet desktop should use lightweight endpoint for getting failing policies count (#8159) Fleet desktop should use lightweight endpoint for getting failing policies count --- ...e-6946-fleet-desktop-uses-minimal-endpoint | 2 + orbit/cmd/desktop/desktop.go | 21 +++----- server/service/device_client.go | 53 ++++++++++++++----- server/service/device_client_test.go | 19 ++++--- server/service/devices.go | 8 +-- server/service/integration_enterprise_test.go | 2 +- 6 files changed, 63 insertions(+), 42 deletions(-) create mode 100644 changes/feature-6946-fleet-desktop-uses-minimal-endpoint diff --git a/changes/feature-6946-fleet-desktop-uses-minimal-endpoint b/changes/feature-6946-fleet-desktop-uses-minimal-endpoint new file mode 100644 index 000000000..7ccf6438b --- /dev/null +++ b/changes/feature-6946-fleet-desktop-uses-minimal-endpoint @@ -0,0 +1,2 @@ +- Updated Fleet Desktop to use the new endpoint introduced in + https://github.com/fleetdm/fleet/issues/7084 diff --git a/orbit/cmd/desktop/desktop.go b/orbit/cmd/desktop/desktop.go index 22d065724..b4c668853 100644 --- a/orbit/cmd/desktop/desktop.go +++ b/orbit/cmd/desktop/desktop.go @@ -120,7 +120,7 @@ func main() { for { refetchToken() - _, err := client.ListDevicePolicies(tokenReader.GetCached()) + _, err := client.NumberOfFailingPolicies(tokenReader.GetCached()) if err == nil || errors.Is(err, service.ErrMissingLicense) { log.Debug().Msg("enabling tray items") @@ -173,7 +173,7 @@ func main() { defer tic.Stop() for { - policies, err := client.ListDevicePolicies(tokenReader.GetCached()) + failingPolicies, err := client.NumberOfFailingPolicies(tokenReader.GetCached()) switch { case err == nil: // OK @@ -185,28 +185,21 @@ func main() { <-checkToken() continue default: - log.Error().Err(err).Msg("get device URL") + log.Error().Err(err).Msg("get failing policies") continue } - failedPolicyCount := 0 - for _, policy := range policies { - if policy.Response != "pass" { - failedPolicyCount++ - } - } - - if failedPolicyCount > 0 { + if failingPolicies > 0 { if runtime.GOOS == "windows" { // Windows (or maybe just the systray library?) doesn't support color emoji // in the system tray menu, so we use text as an alternative. - if failedPolicyCount == 1 { + if failingPolicies == 1 { myDeviceItem.SetTitle("My device (1 issue)") } else { - myDeviceItem.SetTitle(fmt.Sprintf("My device (%d issues)", failedPolicyCount)) + myDeviceItem.SetTitle(fmt.Sprintf("My device (%d issues)", failingPolicies)) } } else { - myDeviceItem.SetTitle(fmt.Sprintf("🔴 My device (%d)", failedPolicyCount)) + myDeviceItem.SetTitle(fmt.Sprintf("🔴 My device (%d)", failingPolicies)) } } else { if runtime.GOOS == "windows" { diff --git a/server/service/device_client.go b/server/service/device_client.go index 58d65c403..e18208186 100644 --- a/server/service/device_client.go +++ b/server/service/device_client.go @@ -49,17 +49,6 @@ func (dc *DeviceClient) request(verb string, path string, query string, response return dc.parseResponse(verb, path, response, responseDest) } -// ListDevicePolicies fetches all policies for the device with the provided token -func (dc *DeviceClient) ListDevicePolicies(token string) ([]*fleet.HostPolicy, error) { - verb, path := "GET", "/api/latest/fleet/device/"+token+"/policies" - var responseBody listDevicePoliciesResponse - err := dc.request(verb, path, "", &responseBody) - if err != nil { - return nil, err - } - return responseBody.Policies, nil -} - // TransparencyURL returns an URL that the server will use to redirect to the // transparency URL configured by the user func (dc *DeviceClient) TransparencyURL(token string) string { @@ -74,8 +63,8 @@ func (dc *DeviceClient) DeviceURL(token string) string { // CheckToken checks if a token is valid by making an authenticated request to // the server func (dc *DeviceClient) CheckToken(token string) error { - verb, path := "GET", "/api/latest/fleet/device/"+token+"/policies" - return dc.request(verb, path, "", &FleetDesktopResponse{}) + _, err := dc.NumberOfFailingPolicies(token) + return err } // Ping sends a ping to the server using the device/ping endpoint @@ -91,3 +80,41 @@ func (dc *DeviceClient) Ping() error { return err } + +func (dc *DeviceClient) getListDevicePolicies(token string) ([]*fleet.HostPolicy, error) { + verb, path := "GET", "/api/latest/fleet/device/"+token+"/policies" + var responseBody listDevicePoliciesResponse + err := dc.request(verb, path, "", &responseBody) + return responseBody.Policies, err +} + +func (dc *DeviceClient) getMinDesktopPayload(token string) (fleetDesktopResponse, error) { + verb, path := "GET", "/api/latest/fleet/device/"+token+"/desktop" + var r fleetDesktopResponse + err := dc.request(verb, path, "", &r) + return r, err +} + +func (dc *DeviceClient) NumberOfFailingPolicies(token string) (uint, error) { + r, err := dc.getMinDesktopPayload(token) + if err == nil { + return uintValueOrZero(r.FailingPolicies), nil + } + + if errors.Is(err, notFoundErr{}) { + policies, err := dc.getListDevicePolicies(token) + if err != nil { + return 0, err + } + + var failingPolicies uint + for _, policy := range policies { + if policy.Response != "pass" { + failingPolicies++ + } + } + return failingPolicies, nil + } + + return 0, err +} diff --git a/server/service/device_client_test.go b/server/service/device_client_test.go index 78cfb86bf..ee702b90d 100644 --- a/server/service/device_client_test.go +++ b/server/service/device_client_test.go @@ -38,24 +38,23 @@ func TestDeviceClientGetDesktopPayload(t *testing.T) { t.Run("with wrong license", func(t *testing.T) { mockRequestDoer.statusCode = http.StatusPaymentRequired - _, err = client.ListDevicePolicies(token) + _, err = client.NumberOfFailingPolicies(token) require.ErrorIs(t, err, ErrMissingLicense) }) - t.Run("with empty policies", func(t *testing.T) { + t.Run("with no failing policies", func(t *testing.T) { mockRequestDoer.statusCode = http.StatusOK - mockRequestDoer.resBody = `{"policies": []}` - policies, err := client.ListDevicePolicies(token) + mockRequestDoer.resBody = `{}` + result, err := client.NumberOfFailingPolicies(token) require.NoError(t, err) - require.Len(t, policies, 0) + require.Equal(t, uint(0), result) }) - t.Run("with policies", func(t *testing.T) { + t.Run("with failing policies", func(t *testing.T) { mockRequestDoer.statusCode = http.StatusOK - mockRequestDoer.resBody = `{"policies": [{"id": 1}]}` - policies, err := client.ListDevicePolicies(token) + mockRequestDoer.resBody = `{"failing_policies_count": 1}` + result, err := client.NumberOfFailingPolicies(token) require.NoError(t, err) - require.Len(t, policies, 1) - require.Equal(t, uint(1), policies[0].ID) + require.Equal(t, uint(1), result) }) } diff --git a/server/service/devices.go b/server/service/devices.go index 6d0fba68f..d1cc42964 100644 --- a/server/service/devices.go +++ b/server/service/devices.go @@ -41,7 +41,7 @@ func (svc *Service) DisableAuthForPing(ctx context.Context) { // Fleet Desktop endpoints ///////////////////////////////////////////////////////////////////////////////// -type FleetDesktopResponse struct { +type fleetDesktopResponse struct { Err error `json:"error,omitempty"` FailingPolicies *uint `json:"failing_policies_count,omitempty"` } @@ -61,15 +61,15 @@ func getFleetDesktopEndpoint(ctx context.Context, request interface{}, svc fleet if !ok { err := ctxerr.Wrap(ctx, fleet.NewAuthRequiredError("internal error: missing host from request context")) - return FleetDesktopResponse{Err: err}, nil + return fleetDesktopResponse{Err: err}, nil } r, err := svc.FailingPoliciesCount(ctx, host) if err != nil { - return FleetDesktopResponse{Err: err}, nil + return fleetDesktopResponse{Err: err}, nil } - return FleetDesktopResponse{FailingPolicies: &r}, nil + return fleetDesktopResponse{FailingPolicies: &r}, nil } ///////////////////////////////////////////////////////////////////////////////// diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index b76b65822..ddc930069 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -1278,7 +1278,7 @@ func (s *integrationEnterpriseTestSuite) TestListDevicePolicies() { require.Len(t, *getDeviceHostResp.Host.Policies, 2) // GET `/api/_version_/fleet/device/{token}/desktop` - getDesktopResp := FleetDesktopResponse{} + getDesktopResp := fleetDesktopResponse{} res = s.DoRawNoAuth("GET", "/api/latest/fleet/device/"+token+"/desktop", nil, http.StatusOK) json.NewDecoder(res.Body).Decode(&getDesktopResp) res.Body.Close()