Add list/detail endpoints for software titles (#15464)

related to #15228
This commit is contained in:
Roberto Dip 2023-12-06 15:28:31 -03:00 committed by GitHub
parent 845ef393bb
commit aada28c1c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1013 additions and 0 deletions

View File

@ -0,0 +1,199 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/jmoiron/sqlx"
)
func (ds *Datastore) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
const selectSoftwareTitleStmt = `
SELECT
st.id,
st.name,
st.source,
COUNT(DISTINCT hs.host_id) AS hosts_count,
COUNT(DISTINCT s.id) AS versions_count
FROM software_titles st
JOIN software s ON s.title_id = st.id
JOIN host_software hs ON hs.software_id = s.id
WHERE st.id = ?
GROUP BY st.id
`
var title fleet.SoftwareTitle
if err := sqlx.GetContext(ctx, ds.reader(ctx), &title, selectSoftwareTitleStmt, id); err != nil {
if err == sql.ErrNoRows {
return nil, notFound("SoftwareTitle").WithID(id)
}
return nil, ctxerr.Wrap(ctx, err, "get software title")
}
selectSoftwareVersionsStmt, args, err := selectSoftwareVersionsSQL([]uint{id}, 0, true)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building versions statement")
}
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &title.Versions, selectSoftwareVersionsStmt, args...); err != nil {
return nil, ctxerr.Wrap(ctx, err, "get software title version")
}
return &title, nil
}
func (ds *Datastore) ListSoftwareTitles(
ctx context.Context,
opt fleet.SoftwareTitleListOptions,
) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) {
dbReader := ds.reader(ctx)
getTitlesStmt, args := selectSoftwareTitlesSQL(opt)
// build the count statement before adding the pagination constraints to `getTitlesStmt`
getTitlesCountStmt := fmt.Sprintf(`SELECT COUNT(DISTINCT s.id) FROM (%s) AS s`, getTitlesStmt)
// grab titles that match the list options
var titles []fleet.SoftwareTitle
getTitlesStmt, args = appendListOptionsWithCursorToSQL(getTitlesStmt, args, &opt.ListOptions)
if err := sqlx.SelectContext(ctx, dbReader, &titles, getTitlesStmt, args...); err != nil {
return nil, 0, nil, ctxerr.Wrap(ctx, err, "select software titles")
}
// perform a second query to grab the counts
var counts int
if err := sqlx.GetContext(ctx, dbReader, &counts, getTitlesCountStmt, args...); err != nil {
return nil, 0, nil, ctxerr.Wrap(ctx, err, "get software titles count")
}
// if we don't have any matching titles, there's no point trying to
// find matching versions. Early return
if len(titles) == 0 {
return titles, counts, &fleet.PaginationMetadata{}, nil
}
// grab all the IDs to find matching versions below
titleIDs := make([]uint, len(titles))
// build an index to quickly access a title by it's ID
titleIndex := make(map[uint]int, len(titles))
for i, title := range titles {
titleIDs[i] = title.ID
titleIndex[title.ID] = i
}
// we grab matching versions separately and build the desired object in
// the application logic. This is because we need to support MySQL 5.7
// and there's no good way to do an aggregation that builds a structure
// (like a JSON) object for nested arrays.
var teamID uint
if opt.TeamID != nil {
teamID = *opt.TeamID
}
getVersionsStmt, args, err := selectSoftwareVersionsSQL(titleIDs, teamID, false)
if err != nil {
return nil, 0, nil, ctxerr.Wrap(ctx, err, "build get versions stmt")
}
var versions []fleet.SoftwareVersion
if err := sqlx.SelectContext(ctx, dbReader, &versions, getVersionsStmt, args...); err != nil {
return nil, 0, nil, ctxerr.Wrap(ctx, err, "get software versions")
}
// append matching versions to titles
for _, version := range versions {
if i, ok := titleIndex[version.TitleID]; ok {
titles[i].Versions = append(titles[i].Versions, version)
}
}
var metaData *fleet.PaginationMetadata
if opt.ListOptions.IncludeMetadata {
metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.ListOptions.Page > 0}
if len(titles) > int(opt.ListOptions.PerPage) {
metaData.HasNextResults = true
titles = titles[:len(titles)-1]
}
}
return titles, counts, metaData, nil
}
func selectSoftwareTitlesSQL(opt fleet.SoftwareTitleListOptions) (string, []any) {
stmt := `
SELECT
st.id,
st.name,
st.source,
COUNT(DISTINCT hs.host_id) AS hosts_count,
COUNT(DISTINCT s.id) AS versions_count
FROM software_titles st
JOIN software s ON s.title_id = st.id
JOIN host_software hs ON hs.software_id = s.id
-- placeholder for changing the JOIN type to filter vulnerable software
%s JOIN software_cve scve ON s.id = scve.software_id
-- placeholder for potential JOIN on hosts
%s
-- placeholder for WHERE clause
WHERE %s
GROUP BY st.id`
cveJoinType := "LEFT"
if opt.VulnerableOnly {
cveJoinType = "INNER"
}
var args []any
hostsJoin := ""
whereClause := "TRUE"
if opt.TeamID != nil {
hostsJoin = "JOIN hosts h ON h.id = hs.host_id"
whereClause = "h.team_id = ?"
args = append(args, opt.TeamID)
}
if match := opt.ListOptions.MatchQuery; match != "" {
whereClause += " AND (st.name LIKE ? OR scve.cve LIKE ?)"
match = likePattern(match)
args = append(args, match, match)
}
stmt = fmt.Sprintf(stmt, cveJoinType, hostsJoin, whereClause)
return stmt, args
}
func selectSoftwareVersionsSQL(titleIDs []uint, teamID uint, withCounts bool) (string, []any, error) {
selectVersionsStmt := `
SELECT
st.id as title_id,
s.id, s.version,
%s -- placeholder for optional host_counts
CONCAT('[', GROUP_CONCAT(JSON_QUOTE(scve.cve) SEPARATOR ','), ']') as vulnerabilities
FROM software_titles st
JOIN software s ON s.title_id = st.id
LEFT JOIN host_software hs ON hs.software_id = s.id
LEFT JOIN software_cve scve ON s.id = scve.software_id
%s -- placeholder for optional JOIN ON host_counts
WHERE st.id IN (?)
GROUP BY s.id`
var args []any
extraSelect := ""
extraJoin := ""
if withCounts {
args = append(args, teamID)
extraSelect = "MAX(shc.hosts_count) AS hosts_count,"
extraJoin = `
JOIN software_host_counts shc
ON shc.software_id = s.id
AND shc.hosts_count > 0
AND shc.team_id = ?
`
}
args = append(args, titleIDs)
selectVersionsStmt = fmt.Sprintf(selectVersionsStmt, extraSelect, extraJoin)
selectVersionsStmt, args, err := sqlx.In(selectVersionsStmt, args...)
if err != nil {
return "", nil, fmt.Errorf("bulding sqlx.In query: %w", err)
}
return selectVersionsStmt, args, nil
}

View File

@ -429,6 +429,12 @@ type Datastore interface {
// are being deleted from the global configuration.
DeleteIntegrationsFromTeams(ctx context.Context, deletedIntgs Integrations) error
///////////////////////////////////////////////////////////////////////////////
// Software Titles
ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error)
SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error)
///////////////////////////////////////////////////////////////////////////////
// SoftwareStore

View File

@ -567,6 +567,12 @@ type Service interface {
SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*Software, error)
CountSoftware(ctx context.Context, opt SoftwareListOptions) (int, error)
// /////////////////////////////////////////////////////////////////////////////
// Software Titles
ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error)
SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error)
// /////////////////////////////////////////////////////////////////////////////
// Team Policies

View File

@ -1,6 +1,7 @@
package fleet
import (
"encoding/json"
"errors"
"fmt"
"strconv"
@ -102,12 +103,57 @@ func (s Software) ToUniqueStr() string {
return strings.Join(ss, SoftwareFieldSeparator)
}
type SliceString []string
func (c *SliceString) Scan(v interface{}) error {
switch tv := v.(type) {
case []byte:
return json.Unmarshal(tv, &c)
}
return errors.New("unsupported type")
}
// SoftwareVersion is an abstraction over the `software` table to support the
// software titles APIs
type SoftwareVersion struct {
ID uint `db:"id" json:"id"`
// Version is the version string we grab for this specific software.
Version string `db:"version" json:"version"`
// Vulnerabilities is the list of CVE names for vulnerabilities found for this version.
Vulnerabilities *SliceString `db:"vulnerabilities" json:"vulnerabilities,omitempty"`
// HostsCount is the number of hosts that use this software version.
HostsCount *uint `db:"hosts_count" json:"hosts_count,omitempty"`
// TitleID is used only as an auxiliary field and it's not part of the
// JSON response.
TitleID uint `db:"title_id" json:"-"`
}
// SoftwareTitle represents a title backed by the `software_titles` table.
type SoftwareTitle struct {
ID uint `json:"id" db:"id"`
// Name is the name reported by osquery.
Name string `json:"name" db:"name"`
// Source is the source reported by osquery.
Source string `json:"source" db:"source"`
// HostsCount is the number of hosts that use this software title.
HostsCount uint `json:"hosts_count" db:"hosts_count"`
// VesionsCount is the number of versions that have the same title.
VersionsCount uint `json:"versions_count" db:"versions_count"`
// Versions countains information about the versions that use this title.
Versions []SoftwareVersion `json:"versions" db:"-"`
// CountsUpdatedAt is the timestamp when the hosts count
// was last updated for that software, filled only if hosts
// count is requested.
CountsUpdatedAt time.Time `json:"-" db:"counts_updated_at"`
}
type SoftwareTitleListOptions struct {
// ListOptions cannot be embedded in order to unmarshall with validation.
ListOptions ListOptions `url:"list_options"`
TeamID *uint `query:"team_id,optional"`
VulnerableOnly bool `query:"vulnerable,optional"`
}
// AuthzSoftwareInventory is used for access controls on software inventory.

View File

@ -324,6 +324,10 @@ type TeamEnrollSecretsFunc func(ctx context.Context, teamID uint) ([]*fleet.Enro
type DeleteIntegrationsFromTeamsFunc func(ctx context.Context, deletedIntgs fleet.Integrations) error
type ListSoftwareTitlesFunc func(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error)
type SoftwareTitleByIDFunc func(ctx context.Context, id uint) (*fleet.SoftwareTitle, error)
type ListSoftwareForVulnDetectionFunc func(ctx context.Context, hostID uint) ([]fleet.Software, error)
type ListSoftwareVulnerabilitiesByHostIDsSourceFunc func(ctx context.Context, hostIDs []uint, source fleet.VulnerabilitySource) (map[uint][]fleet.SoftwareVulnerability, error)
@ -1224,6 +1228,12 @@ type DataStore struct {
DeleteIntegrationsFromTeamsFunc DeleteIntegrationsFromTeamsFunc
DeleteIntegrationsFromTeamsFuncInvoked bool
ListSoftwareTitlesFunc ListSoftwareTitlesFunc
ListSoftwareTitlesFuncInvoked bool
SoftwareTitleByIDFunc SoftwareTitleByIDFunc
SoftwareTitleByIDFuncInvoked bool
ListSoftwareForVulnDetectionFunc ListSoftwareForVulnDetectionFunc
ListSoftwareForVulnDetectionFuncInvoked bool
@ -2958,6 +2968,20 @@ func (s *DataStore) DeleteIntegrationsFromTeams(ctx context.Context, deletedIntg
return s.DeleteIntegrationsFromTeamsFunc(ctx, deletedIntgs)
}
func (s *DataStore) ListSoftwareTitles(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) {
s.mu.Lock()
s.ListSoftwareTitlesFuncInvoked = true
s.mu.Unlock()
return s.ListSoftwareTitlesFunc(ctx, opt)
}
func (s *DataStore) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
s.mu.Lock()
s.SoftwareTitleByIDFuncInvoked = true
s.mu.Unlock()
return s.SoftwareTitleByIDFunc(ctx, id)
}
func (s *DataStore) ListSoftwareForVulnDetection(ctx context.Context, hostID uint) ([]fleet.Software, error) {
s.mu.Lock()
s.ListSoftwareForVulnDetectionFuncInvoked = true

View File

@ -368,6 +368,9 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
// DEPRECATED: software version counts are now included directly in the software version list
ue.GET("/api/_version_/fleet/software/count", countSoftwareEndpoint, countSoftwareRequest{})
ue.GET("/api/_version_/fleet/software/titles", listSoftwareTitlesEndpoint, listSoftwareTitlesRequest{})
ue.GET("/api/_version_/fleet/software/titles/{id:[0-9]+}", getSoftwareTitleEndpoint, getSoftwareTitleRequest{})
ue.GET("/api/_version_/fleet/host_summary", getHostSummaryEndpoint, getHostSummaryRequest{})
ue.GET("/api/_version_/fleet/hosts", listHostsEndpoint, listHostsRequest{})
ue.POST("/api/_version_/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{})

View File

@ -5005,6 +5005,21 @@ func (s *integrationTestSuite) TestPremiumEndpointsWithoutLicense() {
// batch set scripts
s.Do("POST", "/api/v1/fleet/scripts/batch", batchSetScriptsRequest{Scripts: nil}, http.StatusPaymentRequired)
// software titles
// a normal request works fine
var resp listSoftwareTitlesResponse
s.DoJSON("GET", "/api/latest/fleet/software/titles", listSoftwareTitlesRequest{}, http.StatusOK, &resp)
require.Equal(t, 0, resp.Count)
require.Nil(t, resp.SoftwareTitles)
// a request with a team_id parameter returns a license error
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{}, http.StatusPaymentRequired, &resp,
"team_id", "1",
)
}
// TestGlobalPoliciesBrowsing tests that team users can browse (read) global policies (see #3722).

View File

@ -3665,6 +3665,8 @@ func (s *integrationEnterpriseTestSuite) TestGitOpsUserActions() {
s.DoJSON("GET", "/api/latest/fleet/software/versions", listSoftwareRequest{}, http.StatusForbidden, &listSoftwareVersionsResponse{})
s.DoJSON("GET", "/api/latest/fleet/software", listSoftwareRequest{}, http.StatusForbidden, &listSoftwareResponse{})
s.DoJSON("GET", "/api/latest/fleet/software/count", countSoftwareRequest{}, http.StatusForbidden, &countSoftwareResponse{})
s.DoJSON("GET", "/api/latest/fleet/software/titles", listSoftwareTitlesRequest{}, http.StatusForbidden, &listSoftwareTitlesResponse{})
s.DoJSON("GET", "/api/latest/fleet/software/titles/1", getSoftwareTitleRequest{}, http.StatusForbidden, &getSoftwareTitleResponse{})
// Attempt to list a software, should fail.
s.DoJSON("GET", "/api/latest/fleet/software/1", getSoftwareRequest{}, http.StatusForbidden, &getSoftwareResponse{})
@ -5689,6 +5691,431 @@ func (s *integrationEnterpriseTestSuite) TestTeamConfigDetailQueriesOverrides()
require.Contains(t, dqResp.Queries, fmt.Sprintf("fleet_distributed_query_%s", t.Name()))
}
func (s *integrationEnterpriseTestSuite) TestAllSoftwareTitles() {
ctx := context.Background()
t := s.T()
softwareTitlesMatch := func(want, got []fleet.SoftwareTitle) {
// compare only the fields we care about
for i := range got {
require.NotZero(t, got[i].ID)
got[i].ID = 0
for j := range got[i].Versions {
require.NotZero(t, got[i].Versions[j].ID)
got[i].Versions[j].ID = 0
}
}
// sort and use EqualValues instead of ElementsMatch in order
// to do a deep comparison of nested structures
sort.Slice(got, func(i, j int) bool {
return got[i].Name < got[j].Name
})
sort.Slice(want, func(i, j int) bool {
return want[i].Name < want[j].Name
})
require.EqualValues(t, want, got)
}
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now().Add(-1 * time.Minute),
OsqueryHostID: ptr.String(t.Name()),
NodeKey: ptr.String(t.Name()),
UUID: uuid.New().String(),
Hostname: fmt.Sprintf("%sfoo.local", t.Name()),
Platform: "darwin",
})
require.NoError(t, err)
tmHost, err := s.ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now().Add(-1 * time.Minute),
OsqueryHostID: ptr.String(t.Name() + "tm"),
NodeKey: ptr.String(t.Name() + "tm"),
UUID: uuid.New().String(),
Hostname: fmt.Sprintf("%sfoo.local", t.Name()+"tm"),
Platform: "linux",
})
require.NoError(t, err)
// create a couple of teams and add tmHost to one
team1, err := s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"})
require.NoError(t, err)
_, err = s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"})
require.NoError(t, err)
require.NoError(t, s.ds.AddHostsToTeam(ctx, &team1.ID, []uint{tmHost.ID}))
software := []fleet.Software{
{Name: "foo", Version: "0.0.1", Source: "homebrew"},
{Name: "foo", Version: "0.0.3", Source: "homebrew"},
{Name: "bar", Version: "0.0.4", Source: "apps"},
}
_, err = s.ds.UpdateHostSoftware(context.Background(), host.ID, software)
require.NoError(t, err)
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host, false))
soft1 := host.Software[0]
if soft1.Name != "bar" {
soft1 = host.Software[1]
}
cpes := []fleet.SoftwareCPE{{SoftwareID: soft1.ID, CPE: "somecpe"}}
_, err = s.ds.UpsertSoftwareCPEs(context.Background(), cpes)
require.NoError(t, err)
// Reload software so that 'GeneratedCPEID is set.
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host, false))
soft1 = host.Software[0]
if soft1.Name != "bar" {
soft1 = host.Software[1]
}
inserted, err := s.ds.InsertSoftwareVulnerability(
context.Background(), fleet.SoftwareVulnerability{
SoftwareID: soft1.ID,
CVE: "cve-123-123-132",
}, fleet.NVDSource,
)
require.NoError(t, err)
require.True(t, inserted)
// calculate hosts counts
hostsCountTs := time.Now().UTC()
require.NoError(t, s.ds.SyncHostsSoftware(context.Background(), hostsCountTs))
require.NoError(t, s.ds.ReconcileSoftwareTitles(ctx))
t.Run("GET /software/titles", func(t *testing.T) {
var resp listSoftwareTitlesResponse
s.DoJSON("GET", "/api/latest/fleet/software/titles", listSoftwareTitlesRequest{}, http.StatusOK, &resp)
require.Equal(t, 2, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "foo",
Source: "homebrew",
VersionsCount: 2,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.1", Vulnerabilities: nil},
{Version: "0.0.3", Vulnerabilities: nil},
},
},
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}},
},
},
}, resp.SoftwareTitles)
// per_page equals 1, so we get only one item, but the total count is
// still 2
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"per_page", "1",
"order_key", "name",
"order_direction", "desc",
)
require.Equal(t, 2, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "foo",
Source: "homebrew",
VersionsCount: 2,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.1", Vulnerabilities: nil},
{Version: "0.0.3", Vulnerabilities: nil},
},
},
}, resp.SoftwareTitles)
// get the second item
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"per_page", "1",
"page", "1",
"order_key", "name",
"order_direction", "desc",
)
require.Equal(t, 2, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}},
},
},
}, resp.SoftwareTitles)
// asking for a non-existent page returns an empty list
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"per_page", "1",
"page", "4",
"order_key", "name",
"order_direction", "desc",
)
require.Equal(t, 2, resp.Count)
softwareTitlesMatch(nil, resp.SoftwareTitles)
// asking for vulnerable only software returns the expected values
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"vulnerable", "true",
)
require.Equal(t, 1, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}},
},
},
}, resp.SoftwareTitles)
// request titles for team1, nothing there yet
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"team_id", "1",
)
require.Equal(t, 0, resp.Count)
softwareTitlesMatch(nil, resp.SoftwareTitles)
// add new software for tmHost
software = []fleet.Software{
{Name: "foo", Version: "0.0.1", Source: "homebrew"},
{Name: "baz", Version: "0.0.5", Source: "deb_packages"},
}
_, err = s.ds.UpdateHostSoftware(context.Background(), tmHost.ID, software)
require.NoError(t, err)
// calculate hosts counts
hostsCountTs := time.Now().UTC()
require.NoError(t, s.ds.SyncHostsSoftware(context.Background(), hostsCountTs))
require.NoError(t, s.ds.ReconcileSoftwareTitles(ctx))
// request software for the team, this time we get results
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"team_id", "1",
"order_key", "name",
"order_direction", "desc",
)
require.Equal(t, 2, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "baz",
Source: "deb_packages",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.5", Vulnerabilities: nil},
},
},
{
Name: "foo",
Source: "homebrew",
VersionsCount: 1, // NOTE: this value is 1 because the team has only 1 matching host
HostsCount: 1, // NOTE: this value is 1 because the team has only 1 matching host
Versions: []fleet.SoftwareVersion{
{Version: "0.0.1", Vulnerabilities: nil},
{Version: "0.0.3", Vulnerabilities: nil},
},
},
}, resp.SoftwareTitles)
// request software for no-team, we get all results and 2 hosts for
// `"foo"`
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"order_key", "name",
"order_direction", "desc",
)
require.Equal(t, 3, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "baz",
Source: "deb_packages",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.5", Vulnerabilities: nil},
},
},
{
Name: "foo",
Source: "homebrew",
VersionsCount: 2, // NOTE: this value is 2, important because no team filter was applied
HostsCount: 2, // NOTE: this value is 2, important because no team filter was applied
Versions: []fleet.SoftwareVersion{
{Version: "0.0.1", Vulnerabilities: nil},
{Version: "0.0.3", Vulnerabilities: nil},
},
},
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}},
},
},
}, resp.SoftwareTitles)
// match cve by name
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"query", "123",
)
require.Equal(t, 1, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}},
},
},
}, resp.SoftwareTitles)
// match software title by name
resp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &resp,
"query", "ba",
)
require.Equal(t, 2, resp.Count)
softwareTitlesMatch([]fleet.SoftwareTitle{
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}},
},
},
{
Name: "baz",
Source: "deb_packages",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.5", Vulnerabilities: nil},
},
},
}, resp.SoftwareTitles)
})
t.Run("GET /software/titles/:id", func(t *testing.T) {
// find the ID of "foo"
var softwareListResp listSoftwareTitlesResponse
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &softwareListResp,
"query", "foo",
)
require.Equal(t, 1, softwareListResp.Count)
require.Len(t, softwareListResp.SoftwareTitles, 1)
fooTitle := softwareListResp.SoftwareTitles[0]
require.Equal(t, "foo", fooTitle.Name)
// non-existent id is a 404
var resp getSoftwareTitleResponse
s.DoJSON("GET", "/api/latest/fleet/software/titles/999", getSoftwareTitleRequest{}, http.StatusNotFound, &resp)
// valid title
resp = getSoftwareTitleResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", fooTitle.ID), getSoftwareTitleRequest{}, http.StatusOK, &resp)
softwareTitlesMatch([]fleet.SoftwareTitle{{
Name: "foo",
Source: "homebrew",
VersionsCount: 2,
HostsCount: 2,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.1", Vulnerabilities: nil, HostsCount: ptr.Uint(2)},
{Version: "0.0.3", Vulnerabilities: nil, HostsCount: ptr.Uint(1)},
}},
}, []fleet.SoftwareTitle{*resp.SoftwareTitle})
// find the ID of "bar"
softwareListResp = listSoftwareTitlesResponse{}
s.DoJSON(
"GET", "/api/latest/fleet/software/titles",
listSoftwareTitlesRequest{},
http.StatusOK, &softwareListResp,
"query", "bar",
)
require.Equal(t, 1, softwareListResp.Count)
require.Len(t, softwareListResp.SoftwareTitles, 1)
barTitle := softwareListResp.SoftwareTitles[0]
require.Equal(t, "bar", barTitle.Name)
// valid title with vulnerabilities
resp = getSoftwareTitleResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", barTitle.ID), getSoftwareTitleRequest{}, http.StatusOK, &resp)
softwareTitlesMatch([]fleet.SoftwareTitle{{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{
Version: "0.0.4",
Vulnerabilities: &fleet.SliceString{"cve-123-123-132"},
HostsCount: ptr.Uint(1),
},
}},
}, []fleet.SoftwareTitle{*resp.SoftwareTitle})
})
}
// checks that the specified team/no-team has the Windows OS Updates profile with
// the specified deadline/grace settings (or checks that it doesn't have the
// profile if wantSettings is nil). It returns the profile_uuid if it exists,

View File

@ -0,0 +1,126 @@
package service
import (
"context"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
/////////////////////////////////////////////////////////////////////////////////
// List Software Titles
/////////////////////////////////////////////////////////////////////////////////
type listSoftwareTitlesRequest struct {
fleet.SoftwareTitleListOptions
}
type listSoftwareTitlesResponse struct {
Meta *fleet.PaginationMetadata `json:"meta"`
Count int `json:"count"`
CountsUpdatedAt *time.Time `json:"counts_updated_at"`
SoftwareTitles []fleet.SoftwareTitle `json:"software_titles,omitempty"`
Err error `json:"error,omitempty"`
}
func (r listSoftwareTitlesResponse) error() error { return r.Err }
func listSoftwareTitlesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*listSoftwareTitlesRequest)
titles, count, meta, err := svc.ListSoftwareTitles(ctx, req.SoftwareTitleListOptions)
if err != nil {
return listSoftwareTitlesResponse{Err: err}, nil
}
var latest time.Time
for _, sw := range titles {
if !sw.CountsUpdatedAt.IsZero() && sw.CountsUpdatedAt.After(latest) {
latest = sw.CountsUpdatedAt
}
}
listResp := listSoftwareTitlesResponse{
SoftwareTitles: titles,
Count: count,
Meta: meta,
}
if !latest.IsZero() {
listResp.CountsUpdatedAt = &latest
}
return listResp, nil
}
func (svc *Service) ListSoftwareTitles(
ctx context.Context,
opt fleet.SoftwareTitleListOptions,
) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) {
if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{
TeamID: opt.TeamID,
}, fleet.ActionRead); err != nil {
return nil, 0, nil, err
}
if opt.TeamID != nil && *opt.TeamID != 0 {
lic, err := svc.License(ctx)
if err != nil {
return nil, 0, nil, ctxerr.Wrap(ctx, err, "get license")
}
if !lic.IsPremium() {
return nil, 0, nil, fleet.ErrMissingLicense
}
}
// always include metadata for software titles
opt.ListOptions.IncludeMetadata = true
// cursor-based pagination is not supported for software titles
opt.ListOptions.After = ""
titles, count, meta, err := svc.ds.ListSoftwareTitles(ctx, opt)
if err != nil {
return nil, 0, nil, err
}
return titles, count, meta, nil
}
/////////////////////////////////////////////////////////////////////////////////
// Get a Software Title
/////////////////////////////////////////////////////////////////////////////////
type getSoftwareTitleRequest struct {
ID uint `url:"id"`
}
type getSoftwareTitleResponse struct {
SoftwareTitle *fleet.SoftwareTitle `json:"software_title,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getSoftwareTitleResponse) error() error { return r.Err }
func getSoftwareTitleEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*getSoftwareTitleRequest)
software, err := svc.SoftwareTitleByID(ctx, req.ID)
if err != nil {
return getSoftwareTitleResponse{Err: err}, nil
}
return getSoftwareTitleResponse{SoftwareTitle: software}, nil
}
func (svc *Service) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
// TODO: this is the autorization we do for GET /software, does it look right?
// checking with product here: https://github.com/fleetdm/fleet/issues/14674#issuecomment-1841395788
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
software, err := svc.ds.SoftwareTitleByID(ctx, id)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting software title by id")
}
return software, nil
}

View File

@ -0,0 +1,161 @@
package service
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/contexts/license"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/require"
)
func TestServiceSoftwareTitlesAuth(t *testing.T) {
ds := new(mock.Store)
ds.ListSoftwareTitlesFunc = func(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) {
return []fleet.SoftwareTitle{}, 0, &fleet.PaginationMetadata{}, nil
}
ds.SoftwareTitleByIDFunc = func(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
return &fleet.SoftwareTitle{}, nil
}
svc, ctx := newTestService(t, ds, nil, nil)
for _, tc := range []struct {
name string
user *fleet.User
shouldFailGlobalRead bool
shouldFailTeamRead bool
}{
{
name: "global-admin",
user: &fleet.User{
ID: 1,
GlobalRole: ptr.String(fleet.RoleAdmin),
},
shouldFailGlobalRead: false,
shouldFailTeamRead: false,
},
{
name: "global-maintainer",
user: &fleet.User{
ID: 1,
GlobalRole: ptr.String(fleet.RoleMaintainer),
},
shouldFailGlobalRead: false,
shouldFailTeamRead: false,
},
{
name: "global-observer",
user: &fleet.User{
ID: 1,
GlobalRole: ptr.String(fleet.RoleObserver),
},
shouldFailGlobalRead: false,
shouldFailTeamRead: false,
},
{
name: "team-admin-belongs-to-team",
user: &fleet.User{
ID: 1,
Teams: []fleet.UserTeam{{
Team: fleet.Team{ID: 1},
Role: fleet.RoleAdmin,
}},
},
shouldFailGlobalRead: true,
shouldFailTeamRead: false,
},
{
name: "team-maintainer-belongs-to-team",
user: &fleet.User{
ID: 1,
Teams: []fleet.UserTeam{{
Team: fleet.Team{ID: 1},
Role: fleet.RoleMaintainer,
}},
},
shouldFailGlobalRead: true,
shouldFailTeamRead: false,
},
{
name: "team-observer-belongs-to-team",
user: &fleet.User{
ID: 1,
Teams: []fleet.UserTeam{{
Team: fleet.Team{ID: 1},
Role: fleet.RoleObserver,
}},
},
shouldFailGlobalRead: true,
shouldFailTeamRead: false,
},
{
name: "team-admin-does-not-belong-to-team",
user: &fleet.User{
ID: 1,
Teams: []fleet.UserTeam{{
Team: fleet.Team{ID: 2},
Role: fleet.RoleAdmin,
}},
},
shouldFailGlobalRead: true,
shouldFailTeamRead: true,
},
{
name: "team-maintainer-does-not-belong-to-team",
user: &fleet.User{
ID: 1,
Teams: []fleet.UserTeam{{
Team: fleet.Team{ID: 2},
Role: fleet.RoleMaintainer,
}},
},
shouldFailGlobalRead: true,
shouldFailTeamRead: true,
},
{
name: "team-observer-does-not-belong-to-team",
user: &fleet.User{
ID: 1,
Teams: []fleet.UserTeam{{
Team: fleet.Team{ID: 2},
Role: fleet.RoleObserver,
}},
},
shouldFailGlobalRead: true,
shouldFailTeamRead: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := viewer.NewContext(ctx, viewer.Viewer{User: tc.user})
premiumCtx := license.NewContext(ctx, &fleet.LicenseInfo{Tier: fleet.TierPremium})
// List all software titles.
_, _, _, err := svc.ListSoftwareTitles(ctx, fleet.SoftwareTitleListOptions{})
checkAuthErr(t, tc.shouldFailGlobalRead, err)
// List software for a team.
_, _, _, err = svc.ListSoftwareTitles(premiumCtx, fleet.SoftwareTitleListOptions{
TeamID: ptr.Uint(1),
})
checkAuthErr(t, tc.shouldFailTeamRead, err)
// List software for a team should fail no matter what
// with a non-premium context
if !tc.shouldFailTeamRead {
_, _, _, err = svc.ListSoftwareTitles(ctx, fleet.SoftwareTitleListOptions{
TeamID: ptr.Uint(1),
})
require.ErrorContains(t, err, "Requires Fleet Premium license")
}
// Get a software title
_, err = svc.SoftwareTitleByID(ctx, 1)
checkAuthErr(t, false, err)
})
}
}