Add vulnerable filter for software and also wire up the query search (#2604)

* Add vulnerable filter for software and also wire up the query search

* Add documentation

* Update to use software list options
This commit is contained in:
Tomas Touceda 2021-10-20 18:01:20 -03:00 committed by GitHub
parent 92e3e83a74
commit cbe931c0e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 91 additions and 51 deletions

View File

@ -0,0 +1 @@
* Add vulnerable=true/false/1/0 query parameter to the software listing endpoint and wire up the query search for filtering software by name.

View File

@ -532,8 +532,8 @@ func TestGetSoftawre(t *testing.T) {
var gotTeamID *uint var gotTeamID *uint
ds.ListSoftwareFunc = func(ctx context.Context, teamId *uint, opt fleet.ListOptions) ([]fleet.Software, error) { ds.ListSoftwareFunc = func(ctx context.Context, opt fleet.SoftwareListOptions) ([]fleet.Software, error) {
gotTeamID = teamId gotTeamID = opt.TeamID
return []fleet.Software{foo001, foo002, foo003, bar003}, nil return []fleet.Software{foo001, foo002, foo003, bar003}, nil
} }

View File

@ -6439,6 +6439,7 @@ If the `name` is not already associated with an existing team, this API route cr
| order_direction | string | query | **Requires `order_key`**. The direction of the order given the order key. Options include `asc` and `desc`. Default is `asc`. | | order_direction | string | query | **Requires `order_key`**. The direction of the order given the order key. Options include `asc` and `desc`. Default is `asc`. |
| query | string | query | Search query keywords. Searchable fields include `hostname`, `machine_serial`, `uuid`, and `ipv4`. | | query | string | query | Search query keywords. Searchable fields include `hostname`, `machine_serial`, `uuid`, and `ipv4`. |
| team_id | integer | query | _Available in Fleet Premium_ Filters the users to only include users in the specified team. | | team_id | integer | query | _Available in Fleet Premium_ Filters the users to only include users in the specified team. |
| vulnerable | bool | query | If true or 1, only list software that has detected vulnerabilities |
#### Example #### Example

View File

@ -92,7 +92,7 @@ func nothingChanged(current []fleet.Software, incoming []fleet.Software) bool {
} }
func applyChangesForNewSoftwareDB(ctx context.Context, tx sqlx.ExtContext, host *fleet.Host) error { func applyChangesForNewSoftwareDB(ctx context.Context, tx sqlx.ExtContext, host *fleet.Host) error {
storedCurrentSoftware, err := listSoftwareDB(ctx, tx, &host.ID, nil, fleet.ListOptions{}) storedCurrentSoftware, err := listSoftwareDB(ctx, tx, &host.ID, fleet.SoftwareListOptions{})
if err != nil { if err != nil {
return errors.Wrap(err, "loading current software for host") return errors.Wrap(err, "loading current software for host")
} }
@ -200,44 +200,49 @@ func insertNewInstalledHostSoftwareDB(
return nil return nil
} }
func listSoftwareDB(ctx context.Context, q sqlx.QueryerContext, hostID *uint, teamID *uint, opt fleet.ListOptions) ([]fleet.Software, error) { func listSoftwareDB(ctx context.Context, q sqlx.QueryerContext, hostID *uint, opt fleet.SoftwareListOptions) ([]fleet.Software, error) {
hostWhere := `hs.host_id=?` hostWhere := `hs.host_id=?`
if hostID == nil { if hostID == nil {
hostWhere = "TRUE" hostWhere = "TRUE"
} }
teamWhere := `h.team_id=?` teamWhere := `h.team_id=?`
if teamID == nil { if opt.TeamID == nil {
teamWhere = "TRUE" teamWhere = "TRUE"
} }
vulnerableJoin := "LEFT JOIN software_cpe scp ON (s.id=scp.software_id)"
if opt.Vulnerable {
vulnerableJoin = `JOIN software_cpe scp ON (s.id=scp.software_id)
JOIN software_cve scv ON (scp.id=scv.cpe_id)`
}
sql := fmt.Sprintf(` sql := fmt.Sprintf(`
SELECT DISTINCT s.*, coalesce(scp.cpe, "") as generated_cpe SELECT DISTINCT s.*, coalesce(scp.cpe, "") as generated_cpe
FROM host_software hs FROM host_software hs
JOIN hosts h ON (hs.host_id=h.id) JOIN hosts h ON (hs.host_id=h.id)
JOIN software s ON (hs.software_id=s.id) JOIN software s ON (hs.software_id=s.id)
LEFT JOIN software_cpe scp ON (s.id=scp.software_id) %s
WHERE %s AND %s WHERE %s AND %s
GROUP BY s.id, s.name, s.version, s.source, generated_cpe `, vulnerableJoin, hostWhere, teamWhere)
`, hostWhere, teamWhere)
sql = appendListOptionsToSQL(sql, opt)
var result []*fleet.Software var result []*fleet.Software
vars := []interface{}{} vars := []interface{}{}
if hostID != nil { if hostID != nil {
vars = append(vars, hostID) vars = append(vars, hostID)
} }
if teamID != nil { if opt.TeamID != nil {
vars = append(vars, teamID) vars = append(vars, opt.TeamID)
} }
if err := sqlx.SelectContext(ctx, q, &result, sql, vars...); err != nil { sql, listVars := searchLike(sql, vars, opt.MatchQuery, "s.name", "s.version")
sql += ` GROUP BY s.id, s.name, s.version, s.source, generated_cpe `
sql = appendListOptionsToSQL(sql, opt.ListOptions)
if err := sqlx.SelectContext(ctx, q, &result, sql, listVars...); err != nil {
return nil, errors.Wrap(err, "load host software") return nil, errors.Wrap(err, "load host software")
} }
sql = fmt.Sprintf(` sql = fmt.Sprintf(`
SELECT DISTINCT s.id, scv.cve SELECT DISTINCT hs.software_id, scv.cve
FROM host_software hs FROM host_software hs
JOIN hosts h ON (hs.host_id=h.id) JOIN hosts h ON (hs.host_id=h.id)
JOIN software s ON (s.id=hs.software_id) JOIN software_cpe scp ON (hs.software_id=scp.software_id)
JOIN software_cpe scp ON (s.id=scp.software_id)
JOIN software_cve scv ON (scp.id=scv.cpe_id) JOIN software_cve scv ON (scp.id=scv.cpe_id)
WHERE %s AND %s WHERE %s AND %s
`, hostWhere, teamWhere) `, hostWhere, teamWhere)
@ -275,7 +280,7 @@ func listSoftwareDB(ctx context.Context, q sqlx.QueryerContext, hostID *uint, te
func (d *Datastore) LoadHostSoftware(ctx context.Context, host *fleet.Host) error { func (d *Datastore) LoadHostSoftware(ctx context.Context, host *fleet.Host) error {
host.HostSoftware = fleet.HostSoftware{Modified: false} host.HostSoftware = fleet.HostSoftware{Modified: false}
software, err := listSoftwareDB(ctx, d.reader, &host.ID, nil, fleet.ListOptions{}) software, err := listSoftwareDB(ctx, d.reader, &host.ID, fleet.SoftwareListOptions{})
if err != nil { if err != nil {
return err return err
} }
@ -358,8 +363,8 @@ func (d *Datastore) InsertCVEForCPE(ctx context.Context, cve string, cpes []stri
return nil return nil
} }
func (d *Datastore) ListSoftware(ctx context.Context, teamId *uint, opt fleet.ListOptions) ([]fleet.Software, error) { func (d *Datastore) ListSoftware(ctx context.Context, opt fleet.SoftwareListOptions) ([]fleet.Software, error) {
return listSoftwareDB(ctx, d.reader, nil, teamId, opt) return listSoftwareDB(ctx, d.reader, nil, opt)
} }
func (d *Datastore) SoftwareByID(ctx context.Context, id uint) (*fleet.Software, error) { func (d *Datastore) SoftwareByID(ctx context.Context, id uint) (*fleet.Software, error) {

View File

@ -426,7 +426,7 @@ func testSoftwareList(t *testing.T, ds *Datastore) {
bar003 := fleet.Software{Name: "bar", Version: "0.0.3", Source: "deb_packages"} bar003 := fleet.Software{Name: "bar", Version: "0.0.3", Source: "deb_packages"}
t.Run("lists everything", func(t *testing.T) { t.Run("lists everything", func(t *testing.T) {
software, err := ds.ListSoftware(context.Background(), nil, fleet.ListOptions{}) software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{})
require.NoError(t, err) require.NoError(t, err)
require.Len(t, software, 4) require.Len(t, software, 4)
@ -435,7 +435,7 @@ func testSoftwareList(t *testing.T, ds *Datastore) {
}) })
t.Run("limits the results", func(t *testing.T) { t.Run("limits the results", func(t *testing.T) {
software, err := ds.ListSoftware(context.Background(), nil, fleet.ListOptions{PerPage: 1, OrderKey: "version"}) software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{ListOptions: fleet.ListOptions{PerPage: 1, OrderKey: "version"}})
require.NoError(t, err) require.NoError(t, err)
require.Len(t, software, 1) require.Len(t, software, 1)
@ -444,7 +444,7 @@ func testSoftwareList(t *testing.T, ds *Datastore) {
}) })
t.Run("paginates", func(t *testing.T) { t.Run("paginates", func(t *testing.T) {
software, err := ds.ListSoftware(context.Background(), nil, fleet.ListOptions{Page: 1, PerPage: 1, OrderKey: "version"}) software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{ListOptions: fleet.ListOptions{Page: 1, PerPage: 1, OrderKey: "version"}})
require.NoError(t, err) require.NoError(t, err)
require.Len(t, software, 1) require.Len(t, software, 1)
@ -457,7 +457,7 @@ func testSoftwareList(t *testing.T, ds *Datastore) {
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID})) require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
software, err := ds.ListSoftware(context.Background(), &team1.ID, fleet.ListOptions{OrderKey: "version"}) software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{ListOptions: fleet.ListOptions{OrderKey: "version"}, TeamID: &team1.ID})
require.NoError(t, err) require.NoError(t, err)
require.Len(t, software, 2) require.Len(t, software, 2)
@ -470,11 +470,29 @@ func testSoftwareList(t *testing.T, ds *Datastore) {
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID})) require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
software, err := ds.ListSoftware(context.Background(), &team1.ID, fleet.ListOptions{PerPage: 1, Page: 1, OrderKey: "id"}) software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{ListOptions: fleet.ListOptions{PerPage: 1, Page: 1, OrderKey: "id"}, TeamID: &team1.ID})
require.NoError(t, err) require.NoError(t, err)
require.Len(t, software, 1) require.Len(t, software, 1)
expected := []fleet.Software{foo003} expected := []fleet.Software{foo003}
test.ElementsMatchSkipID(t, software, expected) test.ElementsMatchSkipID(t, software, expected)
}) })
t.Run("filters vulnerable software", func(t *testing.T) {
software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{Vulnerable: true})
require.NoError(t, err)
require.Len(t, software, 1)
expected := []fleet.Software{foo001}
test.ElementsMatchSkipID(t, software, expected)
})
t.Run("filters by query", func(t *testing.T) {
software, err := ds.ListSoftware(context.Background(), fleet.SoftwareListOptions{ListOptions: fleet.ListOptions{MatchQuery: "bar"}})
require.NoError(t, err)
require.Len(t, software, 1)
expected := []fleet.Software{bar003}
test.ElementsMatchSkipID(t, software, expected)
})
} }

View File

@ -241,18 +241,18 @@ const (
// listing objects // listing objects
type ListOptions struct { type ListOptions struct {
// Which page to return (must be positive integer) // Which page to return (must be positive integer)
Page uint Page uint `query:"page,optional"`
// How many results per page (must be positive integer, 0 indicates // How many results per page (must be positive integer, 0 indicates
// unlimited) // unlimited)
PerPage uint PerPage uint `query:"per_page,optional"`
// Key to use for ordering // Key to use for ordering
OrderKey string OrderKey string `query:"order_key,optional"`
// Direction of ordering // Direction of ordering
OrderDirection OrderDirection OrderDirection OrderDirection `query:"order_direction,optional"`
// MatchQuery is the query string to match against columns of the entity // MatchQuery is the query string to match against columns of the entity
// (varies depending on entity, eg. hostname, IP address for hosts). // (varies depending on entity, eg. hostname, IP address for hosts).
// Handling for this parameter must be implemented separately for each type. // Handling for this parameter must be implemented separately for each type.
MatchQuery string MatchQuery string `query:"query,optional"`
} }
type ListQueryOptions struct { type ListQueryOptions struct {

View File

@ -360,7 +360,7 @@ type Datastore interface {
// MigrationStatus returns nil if migrations are complete, and an error if migrations need to be run. // MigrationStatus returns nil if migrations are complete, and an error if migrations need to be run.
MigrationStatus(ctx context.Context) (MigrationStatus, error) MigrationStatus(ctx context.Context) (MigrationStatus, error)
ListSoftware(ctx context.Context, teamId *uint, opt ListOptions) ([]Software, error) ListSoftware(ctx context.Context, opt SoftwareListOptions) ([]Software, error)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Team Policies // Team Policies

View File

@ -408,7 +408,7 @@ type Service interface {
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Software // Software
ListSoftware(ctx context.Context, teamID *uint, opt ListOptions) ([]Software, error) ListSoftware(ctx context.Context, opt SoftwareListOptions) ([]Software, error)
SoftwareByID(ctx context.Context, id uint) (*Software, error) SoftwareByID(ctx context.Context, id uint) (*Software, error)
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@ -46,3 +46,10 @@ type SoftwareIterator interface {
Err() error Err() error
Close() error Close() error
} }
type SoftwareListOptions struct {
ListOptions
TeamID *uint `query:"team_id,optional"`
Vulnerable bool `query:"vulnerable,optional"`
}

View File

@ -291,7 +291,7 @@ type MigrateDataFunc func(ctx context.Context) error
type MigrationStatusFunc func(ctx context.Context) (fleet.MigrationStatus, error) type MigrationStatusFunc func(ctx context.Context) (fleet.MigrationStatus, error)
type ListSoftwareFunc func(ctx context.Context, teamId *uint, opt fleet.ListOptions) ([]fleet.Software, error) type ListSoftwareFunc func(ctx context.Context, opt fleet.SoftwareListOptions) ([]fleet.Software, error)
type NewTeamPolicyFunc func(ctx context.Context, teamID uint, queryID uint, resolution string) (*fleet.Policy, error) type NewTeamPolicyFunc func(ctx context.Context, teamID uint, queryID uint, resolution string) (*fleet.Policy, error)
@ -1448,9 +1448,9 @@ func (s *DataStore) MigrationStatus(ctx context.Context) (fleet.MigrationStatus,
return s.MigrationStatusFunc(ctx) return s.MigrationStatusFunc(ctx)
} }
func (s *DataStore) ListSoftware(ctx context.Context, teamId *uint, opt fleet.ListOptions) ([]fleet.Software, error) { func (s *DataStore) ListSoftware(ctx context.Context, opt fleet.SoftwareListOptions) ([]fleet.Software, error) {
s.ListSoftwareFuncInvoked = true s.ListSoftwareFuncInvoked = true
return s.ListSoftwareFunc(ctx, teamId, opt) return s.ListSoftwareFunc(ctx, opt)
} }
func (s *DataStore) NewTeamPolicy(ctx context.Context, teamID uint, queryID uint, resolution string) (*fleet.Policy, error) { func (s *DataStore) NewTeamPolicy(ctx context.Context, teamID uint, queryID uint, resolution string) (*fleet.Policy, error) {

View File

@ -148,7 +148,6 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
return nil, err return nil, err
} }
queryVal := r.URL.Query().Get(queryTagValue) queryVal := r.URL.Query().Get(queryTagValue)
if field.Kind() == reflect.Ptr {
// if optional and it's a ptr, leave as nil // if optional and it's a ptr, leave as nil
if queryVal == "" { if queryVal == "" {
if optional { if optional {
@ -156,6 +155,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
} }
return nil, errors.Errorf("Param %s is required", f.Name) return nil, errors.Errorf("Param %s is required", f.Name)
} }
if field.Kind() == reflect.Ptr {
// create the new instance of whatever it is // create the new instance of whatever it is
field.Set(reflect.New(field.Type().Elem())) field.Set(reflect.New(field.Type().Elem()))
field = field.Elem() field = field.Elem()
@ -169,6 +169,8 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
return nil, errors.Wrap(err, "parsing uint from query") return nil, errors.Wrap(err, "parsing uint from query")
} }
field.SetUint(uint64(queryValUint)) field.SetUint(uint64(queryValUint))
case reflect.Bool:
field.SetBool(queryVal == "1" || queryVal == "true")
default: default:
return nil, errors.Errorf("Cant handle type for field %s %s", f.Name, field.Kind()) return nil, errors.Errorf("Cant handle type for field %s %s", f.Name, field.Kind())
} }

View File

@ -281,7 +281,7 @@ func (s *integrationTestSuite) TestVulnerableSoftware() {
SeenTime: time.Now(), SeenTime: time.Now(),
NodeKey: t.Name() + "1", NodeKey: t.Name() + "1",
UUID: t.Name() + "1", UUID: t.Name() + "1",
Hostname: "foo.local", Hostname: t.Name() + "foo.local",
PrimaryIP: "192.168.1.1", PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58", PrimaryMac: "30-65-EC-6F-C4-58",
}) })
@ -330,6 +330,13 @@ func (s *integrationTestSuite) TestVulnerableSoftware() {
// ignoring other things like timestamps and things that are outside the cope of this ticket // ignoring other things like timestamps and things that are outside the cope of this ticket
assert.Contains(t, string(bodyBytes), expectedJSONSoft2) assert.Contains(t, string(bodyBytes), expectedJSONSoft2)
assert.Contains(t, string(bodyBytes), expectedJSONSoft1) assert.Contains(t, string(bodyBytes), expectedJSONSoft1)
lsReq := listSoftwareRequest{}
lsResp := listSoftwareResponse{}
s.DoJSON("GET", "/api/v1/fleet/software", lsReq, http.StatusOK, &lsResp, "vulnerable", "true")
assert.Len(t, lsResp.Software, 1)
assert.Equal(t, soft1.ID, lsResp.Software[0].ID)
assert.Len(t, lsResp.Software[0].Vulnerabilities, 1)
} }
func (s *integrationTestSuite) TestGlobalPolicies() { func (s *integrationTestSuite) TestGlobalPolicies() {

View File

@ -11,8 +11,7 @@ import (
///////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////
type listSoftwareRequest struct { type listSoftwareRequest struct {
TeamID *uint `query:"team_id,optional"` fleet.SoftwareListOptions
ListOptions fleet.ListOptions `url:"list_options"`
} }
type listSoftwareResponse struct { type listSoftwareResponse struct {
@ -24,17 +23,17 @@ func (r listSoftwareResponse) error() error { return r.Err }
func listSoftwareEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { func listSoftwareEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*listSoftwareRequest) req := request.(*listSoftwareRequest)
resp, err := svc.ListSoftware(ctx, req.TeamID, req.ListOptions) resp, err := svc.ListSoftware(ctx, req.SoftwareListOptions)
if err != nil { if err != nil {
return listSoftwareResponse{Err: err}, nil return listSoftwareResponse{Err: err}, nil
} }
return listSoftwareResponse{Software: resp}, nil return listSoftwareResponse{Software: resp}, nil
} }
func (svc Service) ListSoftware(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]fleet.Software, error) { func (svc Service) ListSoftware(ctx context.Context, opt fleet.SoftwareListOptions) ([]fleet.Software, error) {
if err := svc.authz.Authorize(ctx, &fleet.Software{}, fleet.ActionRead); err != nil { if err := svc.authz.Authorize(ctx, &fleet.Software{}, fleet.ActionRead); err != nil {
return nil, err return nil, err
} }
return svc.ds.ListSoftware(ctx, teamID, opt) return svc.ds.ListSoftware(ctx, opt)
} }

View File

@ -16,9 +16,9 @@ func TestService_ListSoftware(t *testing.T) {
ds := new(mock.Store) ds := new(mock.Store)
var calledWithTeamID *uint var calledWithTeamID *uint
var calledWithOpt fleet.ListOptions var calledWithOpt fleet.SoftwareListOptions
ds.ListSoftwareFunc = func(ctx context.Context, teamId *uint, opt fleet.ListOptions) ([]fleet.Software, error) { ds.ListSoftwareFunc = func(ctx context.Context, opt fleet.SoftwareListOptions) ([]fleet.Software, error) {
calledWithTeamID = teamId calledWithTeamID = opt.TeamID
calledWithOpt = opt calledWithOpt = opt
return []fleet.Software{}, nil return []fleet.Software{}, nil
} }
@ -29,10 +29,10 @@ func TestService_ListSoftware(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx = viewer.NewContext(ctx, viewer.Viewer{User: user}) ctx = viewer.NewContext(ctx, viewer.Viewer{User: user})
_, err := svc.ListSoftware(ctx, ptr.Uint(42), fleet.ListOptions{PerPage: 77, Page: 4}) _, err := svc.ListSoftware(ctx, fleet.SoftwareListOptions{TeamID: ptr.Uint(42), ListOptions: fleet.ListOptions{PerPage: 77, Page: 4}})
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ds.ListSoftwareFuncInvoked) assert.True(t, ds.ListSoftwareFuncInvoked)
assert.Equal(t, ptr.Uint(42), calledWithTeamID) assert.Equal(t, ptr.Uint(42), calledWithTeamID)
assert.Equal(t, fleet.ListOptions{PerPage: 77, Page: 4}, calledWithOpt) assert.Equal(t, fleet.ListOptions{PerPage: 77, Page: 4}, calledWithOpt.ListOptions)
} }