Allow filtering hosts by software_version_id and software_title_id. (#15433)

This commit is contained in:
Martin Angers 2023-12-06 14:59:00 -05:00 committed by GitHub
parent 5fd799ff2e
commit 6b128dd455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 360 additions and 37 deletions

View File

@ -0,0 +1 @@
* Added ability to filter hosts by `software_version_id` and `software_title_id` for the "list hosts", "count hosts" and "get hosts report in CSV" endpoints.

View File

@ -943,9 +943,20 @@ func (ds *Datastore) applyHostFilters(
}
softwareFilter := "TRUE"
if opt.SoftwareIDFilter != nil {
var softwareIDFilter *uint
if opt.SoftwareVersionIDFilter != nil {
softwareIDFilter = opt.SoftwareVersionIDFilter
} else if opt.SoftwareIDFilter != nil {
softwareIDFilter = opt.SoftwareIDFilter
}
if softwareIDFilter != nil {
softwareFilter = "EXISTS (SELECT 1 FROM host_software hs WHERE hs.host_id = h.id AND hs.software_id = ?)"
params = append(params, opt.SoftwareIDFilter)
params = append(params, *softwareIDFilter)
} else if opt.SoftwareTitleIDFilter != nil {
// software (version) ID filter is mutually exclusive with software title ID
// so we're reusing the same filter to avoid adding unnecessary conditions.
softwareFilter = "EXISTS (SELECT 1 FROM host_software hs INNER JOIN software sw ON hs.software_id = sw.id WHERE hs.host_id = h.id AND sw.title_id = ?)"
params = append(params, *opt.SoftwareTitleIDFilter)
}
failingPoliciesJoin := ""

View File

@ -2767,16 +2767,63 @@ func testHostsListBySoftware(t *testing.T, ds *Datastore) {
}
host1 := hosts[0]
host2 := hosts[1]
host3 := hosts[2]
_, err := ds.UpdateHostSoftware(context.Background(), host1.ID, software)
require.NoError(t, err)
_, err = ds.UpdateHostSoftware(context.Background(), host2.ID, software)
require.NoError(t, err)
// host 3 only has foo v0.0.3
_, err = ds.UpdateHostSoftware(context.Background(), host3.ID, software[1:2])
require.NoError(t, err)
// reconcile software, will sync software titles
err = ds.ReconcileSoftwareTitles(context.Background())
require.NoError(t, err)
var fooV002ID uint
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
return sqlx.GetContext(context.Background(), q, &fooV002ID,
"SELECT id FROM software WHERE name = ? AND source = ? AND version = ?", "foo", "chrome_extensions", "0.0.2")
})
var fooTitleID uint
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
return sqlx.GetContext(context.Background(), q, &fooTitleID,
"SELECT id FROM software_titles WHERE name = ? AND source = ?", "foo", "chrome_extensions")
})
require.NoError(t, ds.LoadHostSoftware(context.Background(), host1, false))
require.NoError(t, ds.LoadHostSoftware(context.Background(), host2, false))
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareIDFilter: &host1.Software[0].ID}, 2)
// software_id is foo v0.0.2
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareIDFilter: &fooV002ID}, 2)
require.Len(t, hosts, 2)
got := []uint{hosts[0].ID, hosts[1].ID}
require.ElementsMatch(t, []uint{host1.ID, host2.ID}, got)
// software_version_id is foo v0.0.2 (works exacty the same)
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareVersionIDFilter: &fooV002ID}, 2)
require.Len(t, hosts, 2)
got = []uint{hosts[0].ID, hosts[1].ID}
require.ElementsMatch(t, []uint{host1.ID, host2.ID}, got)
// unknown software_id
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareIDFilter: ptr.Uint(fooV002ID + 100)}, 0)
require.Len(t, hosts, 0)
// unknown software_version_id
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareVersionIDFilter: ptr.Uint(fooV002ID + 100)}, 0)
require.Len(t, hosts, 0)
// software_title_id is foo (any version)
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareTitleIDFilter: &fooTitleID}, 3)
require.Len(t, hosts, 3)
got = []uint{hosts[0].ID, hosts[1].ID, hosts[2].ID}
require.ElementsMatch(t, []uint{host1.ID, host2.ID, host3.ID}, got)
// unknown software_title_id
hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareTitleIDFilter: ptr.Uint(fooTitleID + 100)}, 0)
require.Len(t, hosts, 0)
}
func testHostsListBySoftwareChangedAt(t *testing.T, ds *Datastore) {

View File

@ -126,7 +126,17 @@ type HostListOptions struct {
PolicyIDFilter *uint
PolicyResponseFilter *bool
// Deprecated: SoftwareIDFilter is deprecated as of Fleet 4.42. It is
// maintained for backwards compatibility. Use SoftwareVersionIDFilter
// instead.
SoftwareIDFilter *uint
// SoftwareVersionIDFilter filters the hosts by the software version ID that
// they use. This identifies a specific version of a "software title".
SoftwareVersionIDFilter *uint
// SoftwareTitleIDFilter filers the hosts by the software title ID that they
// use. This identifies a "software title" independent of the specific
// version.
SoftwareTitleIDFilter *uint
OSIDFilter *uint
OSNameFilter *string
@ -179,6 +189,8 @@ func (h HostListOptions) Empty() bool {
h.PolicyIDFilter == nil &&
h.PolicyResponseFilter == nil &&
h.SoftwareIDFilter == nil &&
h.SoftwareVersionIDFilter == nil &&
h.SoftwareTitleIDFilter == nil &&
h.OSIDFilter == nil &&
h.OSNameFilter == nil &&
h.OSVersionFilter == nil &&

View File

@ -54,8 +54,17 @@ type listHostsRequest struct {
}
type listHostsResponse struct {
Hosts []fleet.HostResponse `json:"hosts"`
Software *fleet.Software `json:"software,omitempty"`
Hosts []fleet.HostResponse `json:"hosts"`
// Software is populated with the software version corresponding to the
// software_version_id (or software_id) filter if one is provided with the
// request (and it exists in the database). It is nil otherwise and absent of
// the JSON response payload.
Software *fleet.Software `json:"software,omitempty"`
// SoftwareTitle is populated with the title corresponding to the
// software_title_id filter if one is provided with the request (and it
// exists in the database). It is nil otherwise and absent of the JSON
// response payload.
SoftwareTitle *fleet.SoftwareTitle `json:"software_title,omitempty"`
// MDMSolution is populated with the MDM solution corresponding to the mdm_id
// filter if one is provided with the request (and it exists in the
// database). It is nil otherwise and absent of the JSON response payload.
@ -75,9 +84,24 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi
req := request.(*listHostsRequest)
var software *fleet.Software
if req.Opts.SoftwareIDFilter != nil {
if req.Opts.SoftwareVersionIDFilter != nil || req.Opts.SoftwareIDFilter != nil {
var err error
software, err = svc.SoftwareByID(ctx, *req.Opts.SoftwareIDFilter, false)
id := req.Opts.SoftwareVersionIDFilter
if id == nil {
id = req.Opts.SoftwareIDFilter
}
software, err = svc.SoftwareByID(ctx, *id, false)
if err != nil {
return listHostsResponse{Err: err}, nil
}
}
var softwareTitle *fleet.SoftwareTitle
if req.Opts.SoftwareTitleIDFilter != nil {
var err error
softwareTitle, err = svc.SoftwareTitleByID(ctx, *req.Opts.SoftwareTitleIDFilter)
if err != nil {
return listHostsResponse{Err: err}, nil
}
@ -112,10 +136,11 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi
hostResponses[i] = *h
}
return listHostsResponse{
Hosts: hostResponses,
Software: software,
MDMSolution: mdmSolution,
MunkiIssue: munkiIssue,
Hosts: hostResponses,
Software: software,
SoftwareTitle: softwareTitle,
MDMSolution: mdmSolution,
MunkiIssue: munkiIssue,
}, nil
}

View File

@ -1363,20 +1363,119 @@ func (s *integrationTestSuite) TestListHosts() {
require.Len(t, resp.Hosts, len(hosts)-2)
time.Sleep(1 * time.Second)
host := hosts[2]
// create some software for various hosts
host2 := hosts[2]
software := []fleet.Software{
{Name: "foo", Version: "0.0.1", Source: "chrome_extensions"},
}
_, err := s.ds.UpdateHostSoftware(context.Background(), host.ID, software)
_, err := s.ds.UpdateHostSoftware(context.Background(), host2.ID, software)
require.NoError(t, err)
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host, false))
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host2, false))
host1 := hosts[1]
software = []fleet.Software{
{Name: "foo", Version: "0.0.2", Source: "chrome_extensions"},
{Name: "bar", Version: "0.1.0", Source: "application"},
}
_, err = s.ds.UpdateHostSoftware(context.Background(), host1.ID, software)
require.NoError(t, err)
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host1, false))
host0 := hosts[0]
software = []fleet.Software{
{Name: "foo", Version: "0.0.2", Source: "chrome_extensions"},
{Name: "bar", Version: "0.2.0", Source: "not_application"},
}
_, err = s.ds.UpdateHostSoftware(context.Background(), host0.ID, software)
require.NoError(t, err)
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host0, false))
err = s.ds.SyncHostsSoftware(context.Background(), time.Now())
require.NoError(t, err)
err = s.ds.ReconcileSoftwareTitles(context.Background())
require.NoError(t, err)
var fooV1ID, fooV2ID, barAppTitleID, fooTitleID uint
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
err := sqlx.GetContext(context.Background(), q, &fooV1ID,
`SELECT id FROM software WHERE name = ? AND source = ? AND version = ?`, "foo", "chrome_extensions", "0.0.1")
if err != nil {
return err
}
err = sqlx.GetContext(context.Background(), q, &fooV2ID,
`SELECT id FROM software WHERE name = ? AND source = ? AND version = ?`, "foo", "chrome_extensions", "0.0.2")
if err != nil {
return err
}
err = sqlx.GetContext(context.Background(), q, &barAppTitleID,
`SELECT id FROM software_titles WHERE name = ? AND source = ?`, "bar", "application")
if err != nil {
return err
}
err = sqlx.GetContext(context.Background(), q, &fooTitleID,
`SELECT id FROM software_titles WHERE name = ? AND source = ?`, "foo", "chrome_extensions")
if err != nil {
return err
}
return nil
})
// foo v0.0.1 is only installed on host2
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(host.Software[0].ID))
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(fooV1ID))
require.Len(t, resp.Hosts, 1)
assert.Equal(t, host.ID, resp.Hosts[0].ID)
assert.Equal(t, host2.ID, resp.Hosts[0].ID)
assert.Equal(t, "foo", resp.Software.Name)
assert.Greater(t, resp.Hosts[0].SoftwareUpdatedAt, resp.Hosts[0].CreatedAt)
assert.Nil(t, resp.SoftwareTitle)
var countResp countHostsResponse
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_id", fmt.Sprint(fooV1ID))
require.Equal(t, 1, countResp.Count)
// foo v0.0.2 is installed on hosts 0 and 1
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_version_id", fmt.Sprint(fooV2ID))
require.Len(t, resp.Hosts, 2)
require.ElementsMatch(t, []uint{host0.ID, host1.ID}, []uint{resp.Hosts[0].ID, resp.Hosts[1].ID})
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_version_id", fmt.Sprint(fooV2ID))
require.Equal(t, 2, countResp.Count)
// bar/application title is only on host1
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_title_id", fmt.Sprint(barAppTitleID))
require.Len(t, resp.Hosts, 1)
require.ElementsMatch(t, []uint{host1.ID}, []uint{resp.Hosts[0].ID})
assert.Equal(t, "bar", resp.SoftwareTitle.Name)
assert.Equal(t, "application", resp.SoftwareTitle.Source)
assert.Equal(t, uint(1), resp.SoftwareTitle.HostsCount)
require.Len(t, resp.SoftwareTitle.Versions, 1)
assert.Equal(t, "0.1.0", resp.SoftwareTitle.Versions[0].Version)
assert.Nil(t, resp.Software)
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_title_id", fmt.Sprint(barAppTitleID))
require.Equal(t, 1, countResp.Count)
// foo title is on all 3 hosts
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_title_id", fmt.Sprint(fooTitleID))
require.Len(t, resp.Hosts, 3)
require.ElementsMatch(t, []uint{host0.ID, host1.ID, host2.ID}, []uint{resp.Hosts[0].ID, resp.Hosts[1].ID, resp.Hosts[2].ID})
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_title_id", fmt.Sprint(fooTitleID))
require.Equal(t, 3, countResp.Count)
// verify invalid combinations of software filters
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_title_id", fmt.Sprint(fooTitleID), "software_id", fmt.Sprint(fooV1ID))
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_title_id", fmt.Sprint(fooTitleID), "software_version_id", fmt.Sprint(fooV1ID))
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID))
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID), "software_title_id", fmt.Sprint(fooTitleID))
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_title_id", fmt.Sprint(fooTitleID), "software_id", fmt.Sprint(fooV1ID))
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_title_id", fmt.Sprint(fooTitleID), "software_version_id", fmt.Sprint(fooV1ID))
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID))
s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID), "software_title_id", fmt.Sprint(fooTitleID))
user1 := test.NewUser(t, s.ds, "Alice", "alice@example.com", true)
q := test.NewQuery(t, s.ds, nil, "query1", "select 1", 0, true)
@ -1386,16 +1485,16 @@ func (s *integrationTestSuite) TestListHosts() {
})
require.NoError(t, err)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false))
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(host.Software[0].ID))
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(fooV1ID))
require.Len(t, resp.Hosts, 1)
assert.Equal(t, 1, resp.Hosts[0].HostIssues.FailingPoliciesCount)
assert.Equal(t, 1, resp.Hosts[0].HostIssues.TotalIssuesCount)
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(host.Software[0].ID), "disable_failing_policies", "true")
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_version_id", fmt.Sprint(fooV1ID), "disable_failing_policies", "true")
require.Len(t, resp.Hosts, 1)
assert.Equal(t, 0, resp.Hosts[0].HostIssues.FailingPoliciesCount)
assert.Equal(t, 0, resp.Hosts[0].HostIssues.TotalIssuesCount)
@ -1422,7 +1521,7 @@ func (s *integrationTestSuite) TestListHosts() {
assert.Nil(t, resp.MunkiIssue)
// set MDM information on a host
require.NoError(t, s.ds.SetOrUpdateMDMData(context.Background(), host.ID, false, true, "https://simplemdm.com", false, fleet.WellKnownMDMSimpleMDM))
require.NoError(t, s.ds.SetOrUpdateMDMData(context.Background(), host2.ID, false, true, "https://simplemdm.com", false, fleet.WellKnownMDMSimpleMDM))
var mdmID uint
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
return sqlx.GetContext(context.Background(), q, &mdmID,
@ -1500,13 +1599,17 @@ func (s *integrationTestSuite) TestListHosts() {
// Filter by inexistent software.
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusNotFound, &resp, "software_id", fmt.Sprint(9999))
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusNotFound, &resp, "software_version_id", fmt.Sprint(9999))
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusNotFound, &resp, "software_title_id", fmt.Sprint(9999))
// Filter by non-existent team.
resp = listHostsResponse{}
s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "team_id", fmt.Sprint(9999))
// set munki information on a host
require.NoError(t, s.ds.SetOrUpdateMunkiInfo(context.Background(), host.ID, "1.2.3", []string{"err"}, []string{"warn"}))
require.NoError(t, s.ds.SetOrUpdateMunkiInfo(context.Background(), host2.ID, "1.2.3", []string{"err"}, []string{"warn"}))
var errMunkiID uint
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
return sqlx.GetContext(context.Background(), q, &errMunkiID,
@ -1537,7 +1640,7 @@ func (s *integrationTestSuite) TestListHosts() {
// set operating system information on a host
testOS := fleet.OperatingSystem{Name: "fooOS", Version: "4.2", Arch: "64bit", KernelVersion: "13.37", Platform: "bar"}
require.NoError(t, s.ds.UpdateHostOperatingSystem(context.Background(), host.ID, testOS))
require.NoError(t, s.ds.UpdateHostOperatingSystem(context.Background(), host2.ID, testOS))
var osID uint
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
return sqlx.GetContext(context.Background(), q, &osID,
@ -6567,6 +6670,32 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
require.NoError(t, s.ds.SetOrUpdateHostDisksSpace(ctx, hosts[0].ID, 1.0, 2.0))
require.NoError(t, s.ds.SetOrUpdateHostDisksSpace(ctx, hosts[1].ID, 3.0, 4.0))
// create software for host [0]
software := []fleet.Software{
{Name: "foo", Version: "0.0.1", Source: "chrome_extensions"},
}
_, err = s.ds.UpdateHostSoftware(ctx, hosts[0].ID, software)
require.NoError(t, err)
require.NoError(t, s.ds.LoadHostSoftware(ctx, hosts[0], false))
err = s.ds.ReconcileSoftwareTitles(ctx)
require.NoError(t, err)
var fooV1ID, fooTitleID uint
mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error {
err := sqlx.GetContext(context.Background(), q, &fooV1ID,
`SELECT id FROM software WHERE name = ? AND source = ? AND version = ?`, "foo", "chrome_extensions", "0.0.1")
if err != nil {
return err
}
err = sqlx.GetContext(context.Background(), q, &fooTitleID,
`SELECT id FROM software_titles WHERE name = ? AND source = ?`, "foo", "chrome_extensions")
if err != nil {
return err
}
return nil
})
res := s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusUnsupportedMediaType, "format", "gzip")
var errs validationErrResp
require.NoError(t, json.NewDecoder(res.Body).Decode(&errs))
@ -6654,6 +6783,22 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
require.Len(t, rows, 2) // headers + member host
require.Contains(t, rows[1], hosts[2].Hostname)
// with a software version id
res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "columns", "hostname", "software_version_id", fmt.Sprint(fooV1ID))
rows, err = csv.NewReader(res.Body).ReadAll()
res.Body.Close()
require.NoError(t, err)
require.Len(t, rows, 2) // headers + member host
require.Contains(t, rows[1], hosts[0].Hostname)
// with a software title id
res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "columns", "hostname", "software_title_id", fmt.Sprint(fooTitleID))
rows, err = csv.NewReader(res.Body).ReadAll()
res.Body.Close()
require.NoError(t, err)
require.Len(t, rows, 2) // headers + member host
require.Contains(t, rows[1], hosts[0].Hostname)
// valid format but an invalid column is provided
res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "format", "csv", "columns", "memory,hostname,status,nosuchcolumn")
require.NoError(t, json.NewDecoder(res.Body).Decode(&errs))
@ -6675,6 +6820,12 @@ func (s *integrationTestSuite) TestHostsReportDownload() {
require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local1"}, rows[2][:2])
require.Len(t, rows[3], 3)
require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local0"}, rows[3][:2])
// invalid combinations of software filters
s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_title_id", "123", "software_id", "456")
s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_title_id", "123", "software_version_id", "456")
s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_id", "123", "software_version_id", "456")
s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_id", "123", "software_version_id", "456", "software_title_id", "789")
}
func (s *integrationTestSuite) TestSSODisabled() {

View File

@ -217,7 +217,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
teamID := r.URL.Query().Get("team_id")
if teamID != "" {
id, err := strconv.Atoi(teamID)
id, err := strconv.ParseUint(teamID, 10, 32)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid team_id: %s", teamID)))
}
@ -227,7 +227,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
policyID := r.URL.Query().Get("policy_id")
if policyID != "" {
id, err := strconv.Atoi(policyID)
id, err := strconv.ParseUint(policyID, 10, 32)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid policy_id: %s", policyID)))
}
@ -266,7 +266,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
softwareID := r.URL.Query().Get("software_id")
if softwareID != "" {
id, err := strconv.Atoi(softwareID)
id, err := strconv.ParseUint(softwareID, 10, 64)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid software_id: %s", softwareID)))
}
@ -274,9 +274,29 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
hopt.SoftwareIDFilter = &sid
}
softwareVersionID := r.URL.Query().Get("software_version_id")
if softwareVersionID != "" {
id, err := strconv.ParseUint(softwareVersionID, 10, 64)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid software_version_id: %s", softwareVersionID)))
}
sid := uint(id)
hopt.SoftwareVersionIDFilter = &sid
}
softwareTitleID := r.URL.Query().Get("software_title_id")
if softwareTitleID != "" {
id, err := strconv.ParseUint(softwareTitleID, 10, 32)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid software_title_id: %s", softwareTitleID)))
}
sid := uint(id)
hopt.SoftwareTitleIDFilter = &sid
}
osID := r.URL.Query().Get("os_id")
if osID != "" {
id, err := strconv.Atoi(osID)
id, err := strconv.ParseUint(osID, 10, 32)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid os_id: %s", osID)))
}
@ -336,7 +356,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
mdmID := r.URL.Query().Get("mdm_id")
if mdmID != "" {
id, err := strconv.Atoi(mdmID)
id, err := strconv.ParseUint(mdmID, 10, 32)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid mdm_id: %s", mdmID)))
}
@ -438,7 +458,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
munkiIssueID := r.URL.Query().Get("munki_issue_id")
if munkiIssueID != "" {
id, err := strconv.Atoi(munkiIssueID)
id, err := strconv.ParseUint(munkiIssueID, 10, 32)
if err != nil {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid munki_issue_id: %s", munkiIssueID)))
}
@ -464,6 +484,21 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error)
hopt.LowDiskSpaceFilter = &v
}
// cannot combine software_id, software_version_id, and software_title_id
var softwareErrorLabel []string
if hopt.SoftwareIDFilter != nil {
softwareErrorLabel = append(softwareErrorLabel, "software_id")
}
if hopt.SoftwareVersionIDFilter != nil {
softwareErrorLabel = append(softwareErrorLabel, "software_version_id")
}
if hopt.SoftwareTitleIDFilter != nil {
softwareErrorLabel = append(softwareErrorLabel, "software_title_id")
}
if len(softwareErrorLabel) > 1 {
return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid parameters. The combination of %s is not allowed.", strings.Join(softwareErrorLabel, " and "))))
}
return hopt, nil
}

View File

@ -1,12 +1,14 @@
package service
import (
"github.com/fleetdm/fleet/v4/server/ptr"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -291,6 +293,48 @@ func TestHostListOptionsFromRequest(t *testing.T) {
url: "/foo?os_name=foo",
errorMessage: "Invalid os_version",
},
"negative software_id": {
url: "/foo?software_id=-10",
errorMessage: "Invalid software_id",
},
"negative software_version_id": {
url: "/foo?software_version_id=-10",
errorMessage: "Invalid software_version_id",
},
"negative software_title_id": {
url: "/foo?software_title_id=-10",
errorMessage: "Invalid software_title_id",
},
"software_title_id too big": {
url: "/foo?software_title_id=" + fmt.Sprint(1<<33),
errorMessage: "Invalid software_title_id",
},
"software_version_id can be > 32bits": {
url: "/foo?software_version_id=" + fmt.Sprint(1<<33),
hostListOptions: fleet.HostListOptions{
SoftwareVersionIDFilter: ptr.Uint(1 << 33),
},
},
"good software_version_id": {
url: "/foo?software_version_id=1",
hostListOptions: fleet.HostListOptions{
SoftwareVersionIDFilter: ptr.Uint(1),
},
},
"good software_title_id": {
url: "/foo?software_title_id=1",
hostListOptions: fleet.HostListOptions{
SoftwareTitleIDFilter: ptr.Uint(1),
},
},
"invalid combination software_title_id and software_version_id": {
url: "/foo?software_title_id=1&software_version_id=2",
errorMessage: "The combination of software_version_id and software_title_id is not allowed",
},
"invalid combination software_id and software_version_id": {
url: "/foo?software_id=1&software_version_id=2",
errorMessage: "The combination of software_id and software_version_id is not allowed",
},
}
for name, tt := range hostListOptionsTests {
@ -304,10 +348,7 @@ func TestHostListOptionsFromRequest(t *testing.T) {
assert.NotNil(t, err)
var be *fleet.BadRequestError
require.ErrorAs(t, err, &be)
assert.True(
t, strings.Contains(err.Error(), tt.errorMessage),
"error message '%v' should contain '%v'", err.Error(), tt.errorMessage,
)
require.Contains(t, err.Error(), tt.errorMessage)
return
}
assert.Nil(t, err)