Add team filter to software detail APIs (#16876)

#16787
This commit is contained in:
Victor Lyuboslavsky 2024-02-18 07:14:20 -06:00 committed by GitHub
parent 443ec3c622
commit 35ca4ee32b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 367 additions and 128 deletions

View File

@ -12,7 +12,7 @@ func (svc *Service) ListSoftware(ctx context.Context, opts fleet.SoftwareListOpt
return svc.Service.ListSoftware(ctx, opts)
}
func (svc *Service) SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*fleet.Software, error) {
func (svc *Service) SoftwareByID(ctx context.Context, id uint, teamID *uint, _ bool) (*fleet.Software, error) {
// reuse SoftwareByID, but include cve scores in premium version
return svc.Service.SoftwareByID(ctx, id, true)
return svc.Service.SoftwareByID(ctx, id, teamID, true)
}

View File

@ -4419,6 +4419,9 @@ func (ds *Datastore) UpdateHost(ctx context.Context, host *fleet.Host) error {
func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error) {
jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, notFound("OSVersion")
}
return nil, nil, err
}

View File

@ -1124,7 +1124,7 @@ func (ds *Datastore) DeleteOutOfDateVulnerabilities(ctx context.Context, source
return nil
}
func (ds *Datastore) SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*fleet.Software, error) {
func (ds *Datastore) SoftwareByID(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*fleet.Software, error) {
q := dialect.From(goqu.I("software").As("s")).
Select(
"s.id",
@ -1169,7 +1169,16 @@ func (ds *Datastore) SoftwareByID(ctx context.Context, id uint, includeCVEScores
q = q.Where(goqu.I("s.id").Eq(id))
// filter software that is not associated with any hosts
q = q.Where(goqu.L("EXISTS (SELECT 1 FROM host_software WHERE software_id = ? LIMIT 1)", id))
if teamID == nil {
q = q.Where(goqu.L("EXISTS (SELECT 1 FROM host_software WHERE software_id = ? LIMIT 1)", id))
} else {
// if teamID filter is used, host counts need to be up-to-date
q = q.Where(
goqu.L(
"EXISTS (SELECT 1 FROM software_host_counts WHERE software_id = ? AND team_id = ? AND hosts_count > 0)", id, *teamID,
),
)
}
sql, args, err := q.ToSQL()
if err != nil {

View File

@ -3,6 +3,7 @@ package mysql
import (
"context"
"crypto/md5" // nolint:gosec (only used for tests)
"database/sql"
"encoding/hex"
"fmt"
"math/rand"
@ -101,7 +102,7 @@ func testSoftwareSaveHost(t *testing.T, ds *Datastore) {
host1Software := getHostSoftware(host1)
test.ElementsMatchSkipIDAndHostCount(t, software1, host1Software)
soft1ByID, err := ds.SoftwareByID(context.Background(), host1.HostSoftware.Software[0].ID, false)
soft1ByID, err := ds.SoftwareByID(context.Background(), host1.HostSoftware.Software[0].ID, nil, false)
require.NoError(t, err)
require.NotNil(t, soft1ByID)
assert.Equal(t, host1Software[0], *soft1ByID)
@ -292,7 +293,7 @@ func testSoftwareLoadVulnerabilities(t *testing.T, ds *Datastore) {
}
require.NoError(t, ds.LoadHostSoftware(context.Background(), host, false))
softByID, err := ds.SoftwareByID(context.Background(), host.HostSoftware.Software[0].ID, false)
softByID, err := ds.SoftwareByID(context.Background(), host.HostSoftware.Software[0].ID, nil, false)
require.NoError(t, err)
require.NotNil(t, softByID)
require.Len(t, softByID.Vulnerabilities, 2)
@ -1050,6 +1051,10 @@ func testSoftwareSyncHostsSoftware(t *testing.T, ds *Datastore) {
want = []fleet.Software{}
cmpNameVersionCount(want, team1Counts)
checkTableTotalCount(3)
require.NoError(t, ds.LoadHostSoftware(context.Background(), host1, false))
nilSoftware, err := ds.SoftwareByID(context.Background(), host1.HostSoftware.Software[0].ID, &team1.ID, false)
assert.Nil(t, nilSoftware)
assert.ErrorIs(t, err, sql.ErrNoRows)
// after a call to Calculate, the global counts are updated and the team counts appear
require.NoError(t, ds.SyncHostsSoftware(ctx, time.Now()))
@ -1073,6 +1078,11 @@ func testSoftwareSyncHostsSoftware(t *testing.T, ds *Datastore) {
// composite pk (software_id, team_id), so we expect more rows
checkTableTotalCount(8)
soft1ByID, err := ds.SoftwareByID(context.Background(), host1.HostSoftware.Software[0].ID, &team1.ID, false)
require.NoError(t, err)
software1[0].ID = host1.HostSoftware.Software[0].ID
assert.Equal(t, software1[0], *soft1ByID)
team2Opts := fleet.SoftwareListOptions{WithHostCounts: true, TeamID: ptr.Uint(team2.ID), ListOptions: fleet.ListOptions{OrderKey: "hosts_count", OrderDirection: fleet.OrderDescending}}
team2Counts := listSoftwareCheckCount(t, ds, 2, 2, team2Opts, false)
want = []fleet.Software{
@ -2000,7 +2010,7 @@ func testSoftwareByIDNoDuplicatedVulns(t *testing.T, ds *Datastore) {
}
for _, s := range hostA.Software {
result, err := ds.SoftwareByID(ctx, s.ID, true)
result, err := ds.SoftwareByID(ctx, s.ID, nil, true)
require.NoError(t, err)
require.Len(t, result.Vulnerabilities, 1)
}
@ -2011,6 +2021,9 @@ func testSoftwareByIDIncludesCVEPublishedDate(t *testing.T, ds *Datastore) {
t.Run("software.vulnerabilities includes the published date", func(t *testing.T) {
ctx := context.Background()
host := test.NewHost(t, ds, "hostA", "", "hostAkey", "hostAuuid", time.Now())
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
require.NoError(t, err)
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host.ID}))
now := time.Now().UTC().Truncate(time.Second)
testCases := []struct {
@ -2036,9 +2049,10 @@ func testSoftwareByIDIncludesCVEPublishedDate(t *testing.T, ds *Datastore) {
Source: "apps",
})
}
_, err := ds.UpdateHostSoftware(ctx, host.ID, software)
_, err = ds.UpdateHostSoftware(ctx, host.ID, software)
require.NoError(t, err)
require.NoError(t, ds.LoadHostSoftware(ctx, host, false))
require.NoError(t, ds.SyncHostsSoftware(ctx, time.Now()))
// Add vulnerabilities and CVEMeta
var meta []fleet.CVEMeta
@ -2088,47 +2102,49 @@ func testSoftwareByIDIncludesCVEPublishedDate(t *testing.T, ds *Datastore) {
}
require.NotEqual(t, -1, idx, "software not found")
// Test that scores are not included if includeCVEScores = false
withoutScores, err := ds.SoftwareByID(ctx, host.Software[idx].ID, false)
require.NoError(t, err)
if tC.hasVuln {
require.Len(t, withoutScores.Vulnerabilities, 1)
require.Equal(t, fmt.Sprintf("cve-%s", tC.name), withoutScores.Vulnerabilities[0].CVE)
for _, teamID := range []*uint{nil, &team1.ID} {
// Test that scores are not included if includeCVEScores = false
withoutScores, err := ds.SoftwareByID(ctx, host.Software[idx].ID, teamID, false)
require.NoError(t, err)
if tC.hasVuln {
require.Len(t, withoutScores.Vulnerabilities, 1)
require.Equal(t, fmt.Sprintf("cve-%s", tC.name), withoutScores.Vulnerabilities[0].CVE)
require.Nil(t, withoutScores.Vulnerabilities[0].CVSSScore)
require.Nil(t, withoutScores.Vulnerabilities[0].EPSSProbability)
require.Nil(t, withoutScores.Vulnerabilities[0].CISAKnownExploit)
} else {
require.Empty(t, withoutScores.Vulnerabilities)
}
withScores, err := ds.SoftwareByID(ctx, host.Software[idx].ID, true)
require.NoError(t, err)
if tC.hasVuln {
require.Len(t, withScores.Vulnerabilities, 1)
require.Equal(t, fmt.Sprintf("cve-%s", tC.name), withoutScores.Vulnerabilities[0].CVE)
if tC.hasMeta {
require.NotNil(t, withScores.Vulnerabilities[0].CVSSScore)
require.NotNil(t, *withScores.Vulnerabilities[0].CVSSScore)
require.Equal(t, **withScores.Vulnerabilities[0].CVSSScore, 5.4)
require.NotNil(t, withScores.Vulnerabilities[0].EPSSProbability)
require.NotNil(t, *withScores.Vulnerabilities[0].EPSSProbability)
require.Equal(t, **withScores.Vulnerabilities[0].EPSSProbability, 0.5)
require.NotNil(t, withScores.Vulnerabilities[0].CISAKnownExploit)
require.NotNil(t, *withScores.Vulnerabilities[0].CISAKnownExploit)
require.Equal(t, **withScores.Vulnerabilities[0].CISAKnownExploit, true)
if tC.hasPublishedDate {
require.NotNil(t, withScores.Vulnerabilities[0].CVEPublished)
require.NotNil(t, *withScores.Vulnerabilities[0].CVEPublished)
require.Equal(t, (**withScores.Vulnerabilities[0].CVEPublished), now)
}
require.Nil(t, withoutScores.Vulnerabilities[0].CVSSScore)
require.Nil(t, withoutScores.Vulnerabilities[0].EPSSProbability)
require.Nil(t, withoutScores.Vulnerabilities[0].CISAKnownExploit)
} else {
require.Empty(t, withoutScores.Vulnerabilities)
}
withScores, err := ds.SoftwareByID(ctx, host.Software[idx].ID, teamID, true)
require.NoError(t, err)
if tC.hasVuln {
require.Len(t, withScores.Vulnerabilities, 1)
require.Equal(t, fmt.Sprintf("cve-%s", tC.name), withoutScores.Vulnerabilities[0].CVE)
if tC.hasMeta {
require.NotNil(t, withScores.Vulnerabilities[0].CVSSScore)
require.NotNil(t, *withScores.Vulnerabilities[0].CVSSScore)
require.Equal(t, **withScores.Vulnerabilities[0].CVSSScore, 5.4)
require.NotNil(t, withScores.Vulnerabilities[0].EPSSProbability)
require.NotNil(t, *withScores.Vulnerabilities[0].EPSSProbability)
require.Equal(t, **withScores.Vulnerabilities[0].EPSSProbability, 0.5)
require.NotNil(t, withScores.Vulnerabilities[0].CISAKnownExploit)
require.NotNil(t, *withScores.Vulnerabilities[0].CISAKnownExploit)
require.Equal(t, **withScores.Vulnerabilities[0].CISAKnownExploit, true)
if tC.hasPublishedDate {
require.NotNil(t, withScores.Vulnerabilities[0].CVEPublished)
require.NotNil(t, *withScores.Vulnerabilities[0].CVEPublished)
require.Equal(t, (**withScores.Vulnerabilities[0].CVEPublished), now)
}
}
} else {
require.Empty(t, withoutScores.Vulnerabilities)
}
} else {
require.Empty(t, withoutScores.Vulnerabilities)
}
}
})
@ -2301,7 +2317,7 @@ func testDeleteOutOfDateVulnerabilities(t *testing.T, ds *Datastore) {
err = ds.DeleteOutOfDateVulnerabilities(ctx, fleet.NVDSource, 2*time.Hour)
require.NoError(t, err)
storedSoftware, err := ds.SoftwareByID(ctx, host.Software[0].ID, false)
storedSoftware, err := ds.SoftwareByID(ctx, host.Software[0].ID, nil, false)
require.NoError(t, err)
require.Equal(t, 1, len(storedSoftware.Vulnerabilities))
require.Equal(t, "CVE-2023-001", storedSoftware.Vulnerabilities[0].CVE)
@ -2352,7 +2368,7 @@ func testDeleteSoftwareCPEs(t *testing.T, ds *Datastore) {
require.NoError(t, err)
test.ElementsMatchSkipID(t, cpes[1:], storedCPEs)
storedSoftware, err := ds.SoftwareByID(ctx, cpes[0].SoftwareID, false)
storedSoftware, err := ds.SoftwareByID(ctx, cpes[0].SoftwareID, nil, false)
require.NoError(t, err)
require.Empty(t, storedSoftware.GenerateCPE)
})

View File

@ -12,7 +12,7 @@ import (
"github.com/jmoiron/sqlx"
)
func (ds *Datastore) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
func (ds *Datastore) SoftwareTitleByID(ctx context.Context, id uint, teamID *uint) (*fleet.SoftwareTitle, error) {
const selectSoftwareTitleStmt = `
SELECT
st.id,
@ -24,17 +24,22 @@ SELECT
FROM software_titles st
JOIN software_titles_host_counts sthc ON sthc.software_title_id = st.id
WHERE st.id = ?
AND sthc.team_id = 0
AND sthc.team_id = ?
AND sthc.hosts_count > 0
`
teamIDVal := uint(0)
if teamID != nil {
teamIDVal = *teamID
}
var title fleet.SoftwareTitle
if err := sqlx.GetContext(ctx, ds.reader(ctx), &title, selectSoftwareTitleStmt, id); err != nil {
if err := sqlx.GetContext(ctx, ds.reader(ctx), &title, selectSoftwareTitleStmt, id, teamIDVal); err != nil {
if err == sql.ErrNoRows {
return nil, notFound("SoftwareTitle").WithID(id)
}
return nil, ctxerr.Wrap(ctx, err, "get software title")
}
selectSoftwareVersionsStmt, args, err := selectSoftwareVersionsSQL([]uint{id}, 0, true)
selectSoftwareVersionsStmt, args, err := selectSoftwareVersionsSQL([]uint{id}, teamIDVal, true)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building versions statement")
}

View File

@ -283,6 +283,15 @@ func (ds *Datastore) TeamsSummary(ctx context.Context) ([]*fleet.TeamSummary, er
return teamsSummary, nil
}
func (ds *Datastore) TeamExists(ctx context.Context, teamID uint) (bool, error) {
var exists bool
err := ds.writer(ctx).GetContext(ctx, &exists, "SELECT EXISTS(SELECT 1 FROM teams WHERE id = ?)", teamID)
if err != nil {
return false, ctxerr.Wrap(ctx, err, "team exists")
}
return exists, nil
}
func (ds *Datastore) SearchTeams(ctx context.Context, filter fleet.TeamFilter, matchQuery string, omit ...uint) ([]*fleet.Team, error) {
sql := fmt.Sprintf(`
SELECT *,

View File

@ -456,12 +456,14 @@ type Datastore interface {
// DeleteIntegrationsFromTeams deletes integrations used by teams, as they
// are being deleted from the global configuration.
DeleteIntegrationsFromTeams(ctx context.Context, deletedIntgs Integrations) error
// TeamExists returns true if a team with the given id exists.
TeamExists(ctx context.Context, teamID uint) (bool, error)
///////////////////////////////////////////////////////////////////////////////
// Software Titles
ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error)
SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error)
SoftwareTitleByID(ctx context.Context, id uint, teamID *uint) (*SoftwareTitle, error)
///////////////////////////////////////////////////////////////////////////////
// SoftwareStore
@ -484,7 +486,7 @@ type Datastore interface {
// case it will return true) or if a matching record already exists it will update its
// updated_at timestamp (in which case it will return false).
InsertSoftwareVulnerability(ctx context.Context, vuln SoftwareVulnerability, source VulnerabilitySource) (bool, error)
SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*Software, error)
SoftwareByID(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*Software, error)
// ListSoftwareByHostIDShort lists software by host ID, but does not include CPEs or vulnerabilites.
// It is meant to be used when only minimal software fields are required eg when updating host software.
ListSoftwareByHostIDShort(ctx context.Context, hostID uint) ([]Software, error)

View File

@ -596,14 +596,14 @@ type Service interface {
// Software
ListSoftware(ctx context.Context, opt SoftwareListOptions) ([]Software, *PaginationMetadata, error)
SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*Software, error)
SoftwareByID(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*Software, error)
CountSoftware(ctx context.Context, opt SoftwareListOptions) (int, error)
// /////////////////////////////////////////////////////////////////////////////
// Software Titles
ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error)
SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error)
SoftwareTitleByID(ctx context.Context, id uint, teamID *uint) (*SoftwareTitle, error)
// /////////////////////////////////////////////////////////////////////////////
// Vulnerabilities

View File

@ -350,9 +350,11 @@ type TeamEnrollSecretsFunc func(ctx context.Context, teamID uint) ([]*fleet.Enro
type DeleteIntegrationsFromTeamsFunc func(ctx context.Context, deletedIntgs fleet.Integrations) error
type TeamExistsFunc func(ctx context.Context, teamID uint) (bool, error)
type ListSoftwareTitlesFunc func(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error)
type SoftwareTitleByIDFunc func(ctx context.Context, id uint) (*fleet.SoftwareTitle, error)
type SoftwareTitleByIDFunc func(ctx context.Context, id uint, teamID *uint) (*fleet.SoftwareTitle, error)
type ListSoftwareForVulnDetectionFunc func(ctx context.Context, hostID uint) ([]fleet.Software, error)
@ -370,7 +372,7 @@ type ListSoftwareCPEsFunc func(ctx context.Context) ([]fleet.SoftwareCPE, error)
type InsertSoftwareVulnerabilityFunc func(ctx context.Context, vuln fleet.SoftwareVulnerability, source fleet.VulnerabilitySource) (bool, error)
type SoftwareByIDFunc func(ctx context.Context, id uint, includeCVEScores bool) (*fleet.Software, error)
type SoftwareByIDFunc func(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*fleet.Software, error)
type ListSoftwareByHostIDShortFunc func(ctx context.Context, hostID uint) ([]fleet.Software, error)
@ -1331,6 +1333,9 @@ type DataStore struct {
DeleteIntegrationsFromTeamsFunc DeleteIntegrationsFromTeamsFunc
DeleteIntegrationsFromTeamsFuncInvoked bool
TeamExistsFunc TeamExistsFunc
TeamExistsFuncInvoked bool
ListSoftwareTitlesFunc ListSoftwareTitlesFunc
ListSoftwareTitlesFuncInvoked bool
@ -3219,6 +3224,13 @@ func (s *DataStore) DeleteIntegrationsFromTeams(ctx context.Context, deletedIntg
return s.DeleteIntegrationsFromTeamsFunc(ctx, deletedIntgs)
}
func (s *DataStore) TeamExists(ctx context.Context, teamID uint) (bool, error) {
s.mu.Lock()
s.TeamExistsFuncInvoked = true
s.mu.Unlock()
return s.TeamExistsFunc(ctx, teamID)
}
func (s *DataStore) ListSoftwareTitles(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) {
s.mu.Lock()
s.ListSoftwareTitlesFuncInvoked = true
@ -3226,11 +3238,11 @@ func (s *DataStore) ListSoftwareTitles(ctx context.Context, opt fleet.SoftwareTi
return s.ListSoftwareTitlesFunc(ctx, opt)
}
func (s *DataStore) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
func (s *DataStore) SoftwareTitleByID(ctx context.Context, id uint, teamID *uint) (*fleet.SoftwareTitle, error) {
s.mu.Lock()
s.SoftwareTitleByIDFuncInvoked = true
s.mu.Unlock()
return s.SoftwareTitleByIDFunc(ctx, id)
return s.SoftwareTitleByIDFunc(ctx, id, teamID)
}
func (s *DataStore) ListSoftwareForVulnDetection(ctx context.Context, hostID uint) ([]fleet.Software, error) {
@ -3289,11 +3301,11 @@ func (s *DataStore) InsertSoftwareVulnerability(ctx context.Context, vuln fleet.
return s.InsertSoftwareVulnerabilityFunc(ctx, vuln, source)
}
func (s *DataStore) SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*fleet.Software, error) {
func (s *DataStore) SoftwareByID(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*fleet.Software, error) {
s.mu.Lock()
s.SoftwareByIDFuncInvoked = true
s.mu.Unlock()
return s.SoftwareByIDFunc(ctx, id, includeCVEScores)
return s.SoftwareByIDFunc(ctx, id, teamID, includeCVEScores)
}
func (s *DataStore) ListSoftwareByHostIDShort(ctx context.Context, hostID uint) ([]fleet.Software, error) {

View File

@ -93,7 +93,7 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi
if id == nil {
id = req.Opts.SoftwareIDFilter
}
software, err = svc.SoftwareByID(ctx, *id, false)
software, err = svc.SoftwareByID(ctx, *id, req.Opts.TeamFilter, false)
if err != nil {
return listHostsResponse{Err: err}, nil
}
@ -103,7 +103,7 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi
if req.Opts.SoftwareTitleIDFilter != nil {
var err error
softwareTitle, err = svc.SoftwareTitleByID(ctx, *req.Opts.SoftwareTitleIDFilter)
softwareTitle, err = svc.SoftwareTitleByID(ctx, *req.Opts.SoftwareTitleIDFilter, nil)
if err != nil {
return listHostsResponse{Err: err}, nil
}
@ -1905,22 +1905,35 @@ func getOSVersionEndpoint(ctx context.Context, request interface{}, svc fleet.Se
if err != nil {
return getOSVersionResponse{Err: err}, nil
}
if osVersion == nil {
osVersion = &fleet.OSVersion{}
}
return getOSVersionResponse{CountsUpdatedAt: updateTime, OSVersion: osVersion}, nil
}
func (svc *Service) OSVersion(ctx context.Context, osID uint, teamID *uint, includeCVSS bool) (*fleet.OSVersion, *time.Time, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
if err := svc.authz.Authorize(ctx, &fleet.Host{TeamID: teamID}, fleet.ActionList); err != nil {
return nil, nil, err
}
if teamID != nil {
exists, err := svc.ds.TeamExists(ctx, *teamID)
if err != nil {
return nil, nil, ctxerr.Wrap(ctx, err, "checking if team exists")
} else if !exists {
return nil, nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil)
}
}
osVersion, updateTime, err := svc.ds.OSVersion(ctx, osID, teamID)
if err != nil {
return nil, nil, err
}
if err = svc.populateOSVersionDetails(ctx, osVersion, includeCVSS); err != nil {
return nil, nil, err
if osVersion != nil {
if err = svc.populateOSVersionDetails(ctx, osVersion, includeCVSS); err != nil {
return nil, nil, err
}
}
return osVersion, updateTime, nil

View File

@ -59,7 +59,7 @@ func (s *integrationTestSuite) TearDownTest() {
func TestIntegrations(t *testing.T) {
testingSuite := new(integrationTestSuite)
testingSuite.s = &testingSuite.Suite
testingSuite.withServer.s = &testingSuite.Suite
suite.Run(t, testingSuite)
}
@ -5928,12 +5928,12 @@ func (s *integrationTestSuite) TestListSoftwareAndSoftwareDetails() {
require.NoError(t, s.ds.AddHostsToTeam(context.Background(), &tm.ID, []uint{hosts[19].ID, hosts[18].ID, hosts[17].ID}))
expectedTeamVersionsCount := 3
assertSoftwareDetails := func(expectedSoftware []fleet.Software) {
assertSoftwareDetails := func(expectedSoftware []fleet.Software, team string) {
// this is just a basic sanity check of the software details endpoints and doesn't test all of the
// fields that may be present in the response (e.g., vulnerabilities)
for _, sw := range expectedSoftware {
var detailsResp getSoftwareResponse
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/%d", sw.ID), nil, http.StatusOK, &detailsResp)
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/%d", sw.ID), nil, http.StatusOK, &detailsResp, "team_id", team)
assert.Equal(t, sw.ID, detailsResp.Software.ID)
assert.Equal(t, sw.Name, detailsResp.Software.Name)
assert.Equal(t, sw.Version, detailsResp.Software.Version)
@ -5941,7 +5941,7 @@ func (s *integrationTestSuite) TestListSoftwareAndSoftwareDetails() {
assert.Equal(t, sw.Browser, detailsResp.Software.Browser)
detailsResp = getSoftwareResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/versions/%d", sw.ID), nil, http.StatusOK, &detailsResp)
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/versions/%d", sw.ID), nil, http.StatusOK, &detailsResp, "team_id", team)
assert.Equal(t, sw.ID, detailsResp.Software.ID)
assert.Equal(t, sw.Name, detailsResp.Software.Name)
assert.Equal(t, sw.Version, detailsResp.Software.Version)
@ -5950,7 +5950,7 @@ func (s *integrationTestSuite) TestListSoftwareAndSoftwareDetails() {
}
}
assertResp := func(resp listSoftwareResponse, want []fleet.Software, ts time.Time, counts ...int) {
assertResp := func(resp listSoftwareResponse, want []fleet.Software, ts time.Time, team string, counts ...int) {
require.Len(t, resp.Software, len(want))
for i := range resp.Software {
wantID, gotID := want[i].ID, resp.Software[i].ID
@ -5972,10 +5972,12 @@ func (s *integrationTestSuite) TestListSoftwareAndSoftwareDetails() {
require.NotNil(t, resp.CountsUpdatedAt)
assert.WithinDuration(t, ts, *resp.CountsUpdatedAt, time.Second)
}
assertSoftwareDetails(resp.Software)
assertSoftwareDetails(resp.Software, team)
}
assertVersionsResp := func(resp listSoftwareVersionsResponse, want []fleet.Software, ts time.Time, swCount int, hostCounts ...int) {
assertVersionsResp := func(
resp listSoftwareVersionsResponse, want []fleet.Software, ts time.Time, team string, swCount int, hostCounts ...int,
) {
require.Equal(t, swCount, resp.Count)
require.Len(t, resp.Software, len(want))
for i := range resp.Software {
@ -5998,24 +6000,31 @@ func (s *integrationTestSuite) TestListSoftwareAndSoftwareDetails() {
require.NotNil(t, resp.CountsUpdatedAt)
assert.WithinDuration(t, ts, *resp.CountsUpdatedAt, time.Second)
}
assertSoftwareDetails(resp.Software)
assertSoftwareDetails(resp.Software, team)
}
// no software host counts have been calculated yet, so this returns nothing
var lsResp listSoftwareResponse
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, nil, time.Time{})
assertResp(lsResp, nil, time.Time{}, "")
var versResp listSoftwareVersionsResponse
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, nil, time.Time{}, 0)
assertVersionsResp(versResp, nil, time.Time{}, "", 0)
// same with a team filter
teamStr := fmt.Sprintf("%d", tm.ID)
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "order_key", "hosts_count", "order_direction", "desc", "team_id", fmt.Sprintf("%d", tm.ID))
assertResp(lsResp, nil, time.Time{})
s.DoJSON(
"GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "order_key", "hosts_count", "order_direction", "desc", "team_id",
teamStr,
)
assertResp(lsResp, nil, time.Time{}, teamStr)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "order_key", "hosts_count", "order_direction", "desc", "team_id", fmt.Sprintf("%d", tm.ID))
assertVersionsResp(versResp, nil, time.Time{}, 0)
s.DoJSON(
"GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "order_key", "hosts_count", "order_direction", "desc",
"team_id", teamStr,
)
assertVersionsResp(versResp, nil, time.Time{}, teamStr, 0)
// calculate hosts counts
hostsCountTs := time.Now().UTC()
@ -6024,106 +6033,133 @@ func (s *integrationTestSuite) TestListSoftwareAndSoftwareDetails() {
// now the list software endpoint returns the software, get the first page without vulns
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "5", "page", "0", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, []fleet.Software{sws[19], sws[18], sws[17], sws[16], sws[15]}, hostsCountTs, 20, 19, 18, 17, 16)
assertResp(lsResp, []fleet.Software{sws[19], sws[18], sws[17], sws[16], sws[15]}, hostsCountTs, "", 20, 19, 18, 17, 16)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "5", "page", "0", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, []fleet.Software{sws[19], sws[18], sws[17], sws[16], sws[15]}, hostsCountTs, len(sws), 20, 19, 18, 17, 16)
assertVersionsResp(
versResp, []fleet.Software{sws[19], sws[18], sws[17], sws[16], sws[15]}, hostsCountTs, "", len(sws), 20, 19, 18, 17, 16,
)
require.False(t, versResp.Meta.HasPreviousResults)
require.True(t, versResp.Meta.HasNextResults)
// second page (page=1)
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "5", "page", "1", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, []fleet.Software{sws[14], sws[13], sws[12], sws[11], sws[10]}, hostsCountTs, 15, 14, 13, 12, 11)
assertResp(lsResp, []fleet.Software{sws[14], sws[13], sws[12], sws[11], sws[10]}, hostsCountTs, "", 15, 14, 13, 12, 11)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "5", "page", "1", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, []fleet.Software{sws[14], sws[13], sws[12], sws[11], sws[10]}, hostsCountTs, len(sws), 15, 14, 13, 12, 11)
assertVersionsResp(
versResp, []fleet.Software{sws[14], sws[13], sws[12], sws[11], sws[10]}, hostsCountTs, "", len(sws), 15, 14, 13, 12, 11,
)
require.True(t, versResp.Meta.HasPreviousResults)
require.True(t, versResp.Meta.HasNextResults)
// third page (page=2)
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "5", "page", "2", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, 10, 9, 8, 7, 6)
assertResp(lsResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, "", 10, 9, 8, 7, 6)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "5", "page", "2", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, len(sws), 10, 9, 8, 7, 6)
assertVersionsResp(versResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, "", len(sws), 10, 9, 8, 7, 6)
require.True(t, versResp.Meta.HasPreviousResults)
require.True(t, versResp.Meta.HasNextResults)
// last page (page=3)
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "5", "page", "3", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, 5, 4, 3, 2, 1)
assertResp(lsResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, "", 5, 4, 3, 2, 1)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "5", "page", "3", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, len(sws), 5, 4, 3, 2, 1)
assertVersionsResp(versResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, "", len(sws), 5, 4, 3, 2, 1)
require.True(t, versResp.Meta.HasPreviousResults)
require.False(t, versResp.Meta.HasNextResults)
// past the end
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "5", "page", "4", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, nil, time.Time{})
assertResp(lsResp, nil, time.Time{}, "")
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "5", "page", "4", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, nil, time.Time{}, len(sws))
assertVersionsResp(versResp, nil, time.Time{}, "", len(sws))
// no explicit sort order, defaults to hosts_count DESC
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "2", "page", "0")
assertResp(lsResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, 20, 19)
assertResp(lsResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, "", 20, 19)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "2", "page", "0")
assertVersionsResp(versResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, len(sws), 20, 19)
assertVersionsResp(versResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, "", len(sws), 20, 19)
// hosts_count ascending
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "3", "page", "0", "order_key", "hosts_count", "order_direction", "asc")
assertResp(lsResp, []fleet.Software{sws[0], sws[1], sws[2]}, hostsCountTs, 1, 2, 3)
assertResp(lsResp, []fleet.Software{sws[0], sws[1], sws[2]}, hostsCountTs, "", 1, 2, 3)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "3", "page", "0", "order_key", "hosts_count", "order_direction", "asc")
assertVersionsResp(versResp, []fleet.Software{sws[0], sws[1], sws[2]}, hostsCountTs, len(sws), 1, 2, 3)
assertVersionsResp(versResp, []fleet.Software{sws[0], sws[1], sws[2]}, hostsCountTs, "", len(sws), 1, 2, 3)
// vulnerable software only
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "vulnerable", "true", "per_page", "5", "page", "0", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, 10, 9, 8, 7, 6)
assertResp(lsResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, "", 10, 9, 8, 7, 6)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "vulnerable", "true", "per_page", "5", "page", "0", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, expectedVulnVersionsCount, 10, 9, 8, 7, 6)
assertVersionsResp(
versResp, []fleet.Software{sws[9], sws[8], sws[7], sws[6], sws[5]}, hostsCountTs, "", expectedVulnVersionsCount, 10, 9, 8, 7, 6,
)
// vulnerable software only, next page
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "vulnerable", "true", "per_page", "5", "page", "1", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, 5, 4, 3, 2, 1)
assertResp(lsResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, "", 5, 4, 3, 2, 1)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "vulnerable", "true", "per_page", "5", "page", "1", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, expectedVulnVersionsCount, 5, 4, 3, 2, 1)
assertVersionsResp(
versResp, []fleet.Software{sws[4], sws[3], sws[2], sws[1], sws[0]}, hostsCountTs, "", expectedVulnVersionsCount, 5, 4, 3, 2, 1,
)
// vulnerable software only, past last page
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "vulnerable", "true", "per_page", "5", "page", "2", "order_key", "hosts_count", "order_direction", "desc")
assertResp(lsResp, nil, time.Time{})
assertResp(lsResp, nil, time.Time{}, "")
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "vulnerable", "true", "per_page", "5", "page", "2", "order_key", "hosts_count", "order_direction", "desc")
assertVersionsResp(versResp, nil, time.Time{}, expectedVulnVersionsCount)
assertVersionsResp(versResp, nil, time.Time{}, "", expectedVulnVersionsCount)
// filter by the team, 2 by page
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "2", "page", "0", "order_key", "hosts_count", "order_direction", "desc", "team_id", fmt.Sprintf("%d", tm.ID))
assertResp(lsResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, 3, 2)
s.DoJSON(
"GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "2", "page", "0", "order_key", "hosts_count",
"order_direction", "desc", "team_id", teamStr,
)
assertResp(lsResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, teamStr, 3, 2)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "2", "page", "0", "order_key", "hosts_count", "order_direction", "desc", "team_id", fmt.Sprintf("%d", tm.ID))
assertVersionsResp(versResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, expectedTeamVersionsCount, 3, 2)
s.DoJSON(
"GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "2", "page", "0", "order_key",
"hosts_count", "order_direction", "desc", "team_id", teamStr,
)
assertVersionsResp(versResp, []fleet.Software{sws[19], sws[18]}, hostsCountTs, teamStr, expectedTeamVersionsCount, 3, 2)
// filter by the team, 2 by page, next page
lsResp = listSoftwareResponse{}
s.DoJSON("GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "2", "page", "1", "order_key", "hosts_count", "order_direction", "desc", "team_id", fmt.Sprintf("%d", tm.ID))
assertResp(lsResp, []fleet.Software{sws[17]}, hostsCountTs, 1)
s.DoJSON(
"GET", "/api/latest/fleet/software", nil, http.StatusOK, &lsResp, "per_page", "2", "page", "1", "order_key", "hosts_count",
"order_direction", "desc", "team_id", teamStr,
)
assertResp(lsResp, []fleet.Software{sws[17]}, hostsCountTs, teamStr, 1)
versResp = listSoftwareVersionsResponse{}
s.DoJSON("GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "2", "page", "1", "order_key", "hosts_count", "order_direction", "desc", "team_id", fmt.Sprintf("%d", tm.ID))
assertVersionsResp(versResp, []fleet.Software{sws[17]}, hostsCountTs, expectedTeamVersionsCount, 1)
s.DoJSON(
"GET", "/api/latest/fleet/software/versions", nil, http.StatusOK, &versResp, "per_page", "2", "page", "1", "order_key",
"hosts_count", "order_direction", "desc", "team_id", teamStr,
)
assertVersionsResp(versResp, []fleet.Software{sws[17]}, hostsCountTs, teamStr, expectedTeamVersionsCount, 1)
// Invalid software team
detailsResp := getSoftwareResponse{}
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/software/versions/%d", versResp.Software[0].ID), nil, http.StatusForbidden, &detailsResp,
"team_id", "999999",
)
}
func (s *integrationTestSuite) TestChangeUserEmail() {

View File

@ -3401,6 +3401,27 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() {
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusOK, &osVersionResp)
require.Equal(t, &osVersionsResp.OSVersions[0], osVersionResp.OSVersion)
// OS versions with invalid team
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusForbidden, &osVersionResp, "team_id",
"99999",
)
// Create team and ask for the OS versions from the team (with no hosts) -- should get 404.
tr := teamResponse{}
s.DoJSON(
"POST", "/api/latest/fleet/teams", createTeamRequest{
TeamPayload: fleet.TeamPayload{
Name: ptr.String("os_versions_team"),
},
}, http.StatusOK, &tr,
)
osVersionResp = getOSVersionResponse{}
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusNotFound, &osVersionResp, "team_id",
fmt.Sprintf("%d", tr.Team.ID),
)
// return empty json if UpdateOSVersions cron hasn't run yet for new team
team, err := s.ds.NewTeam(context.Background(), &fleet.Team{Name: "new team"})
require.NoError(t, err)
@ -6352,7 +6373,7 @@ func (s *integrationEnterpriseTestSuite) TestAllSoftwareTitles() {
// create a couple of teams and add tmHost to one
team1, err := s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"})
require.NoError(t, err)
_, err = s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"})
team2, err := s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"})
require.NoError(t, err)
require.NoError(t, s.ds.AddHostsToTeam(ctx, &team1.ID, []uint{tmHost.ID}))
@ -6697,6 +6718,26 @@ func (s *integrationEnterpriseTestSuite) TestAllSoftwareTitles() {
},
}, []fleet.SoftwareTitle{*stResp.SoftwareTitle})
// valid title for team
stResp = getSoftwareTitleResponse{}
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", fooTitle.ID), getSoftwareTitleRequest{}, http.StatusOK, &stResp,
"team_id", fmt.Sprintf("%d", team1.ID),
)
softwareTitlesMatch(
[]fleet.SoftwareTitle{
{
Name: "foo",
Source: "homebrew",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{Version: "0.0.1", Vulnerabilities: nil, HostsCount: ptr.Uint(1)},
},
},
}, []fleet.SoftwareTitle{*stResp.SoftwareTitle},
)
// find the ID of "bar"
resp = listSoftwareTitlesResponse{}
s.DoJSON(
@ -6728,6 +6769,63 @@ func (s *integrationEnterpriseTestSuite) TestAllSoftwareTitles() {
},
},
}, []fleet.SoftwareTitle{*stResp.SoftwareTitle})
// invalid title for team
stResp = getSoftwareTitleResponse{}
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", barTitle.ID), getSoftwareTitleRequest{}, http.StatusNotFound, &stResp,
"team_id", fmt.Sprintf("%d", team1.ID),
)
// add bar tmHost
software = []fleet.Software{
{Name: "bar", Version: "0.0.4", Source: "apps"},
}
_, err = s.ds.UpdateHostSoftware(context.Background(), tmHost.ID, software)
require.NoError(t, err)
// calculate hosts counts
hostsCountTs = time.Now().UTC()
require.NoError(t, s.ds.SyncHostsSoftware(context.Background(), hostsCountTs))
require.NoError(t, s.ds.ReconcileSoftwareTitles(ctx))
require.NoError(t, s.ds.SyncHostsSoftwareTitles(ctx, hostsCountTs))
// valid title with vulnerabilities
stResp = getSoftwareTitleResponse{}
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", barTitle.ID), getSoftwareTitleRequest{}, http.StatusOK, &stResp,
"team_id", fmt.Sprintf("%d", team1.ID),
)
softwareTitlesMatch(
[]fleet.SoftwareTitle{
{
Name: "bar",
Source: "apps",
VersionsCount: 1,
HostsCount: 1,
Versions: []fleet.SoftwareVersion{
{
Version: "0.0.4",
Vulnerabilities: &fleet.SliceString{"cve-123-123-132"},
HostsCount: ptr.Uint(1),
},
},
},
}, []fleet.SoftwareTitle{*stResp.SoftwareTitle},
)
// Team without hosts
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", barTitle.ID), getSoftwareTitleRequest{}, http.StatusNotFound, &stResp,
"team_id", fmt.Sprintf("%d", team2.ID),
)
// Non-existent team
s.DoJSON(
"GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", barTitle.ID), getSoftwareTitleRequest{}, http.StatusForbidden, &stResp,
"team_id", "99999",
)
}
func (s *integrationEnterpriseTestSuite) TestLockUnlockWindowsLinux() {

View File

@ -2,6 +2,8 @@ package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -119,7 +121,8 @@ func (svc *Service) ListSoftware(ctx context.Context, opt fleet.SoftwareListOpti
/////////////////////////////////////////////////////////////////////////////////
type getSoftwareRequest struct {
ID uint `url:"id"`
ID uint `url:"id"`
TeamID *uint `query:"team_id,optional"`
}
type getSoftwareResponse struct {
@ -132,7 +135,7 @@ func (r getSoftwareResponse) error() error { return r.Err }
func getSoftwareEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*getSoftwareRequest)
software, err := svc.SoftwareByID(ctx, req.ID, false)
software, err := svc.SoftwareByID(ctx, req.ID, req.TeamID, false)
if err != nil {
return getSoftwareResponse{Err: err}, nil
}
@ -140,12 +143,20 @@ func getSoftwareEndpoint(ctx context.Context, request interface{}, svc fleet.Ser
return getSoftwareResponse{Software: software}, nil
}
func (svc *Service) SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*fleet.Software, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
func (svc *Service) SoftwareByID(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*fleet.Software, error) {
if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{TeamID: teamID}, fleet.ActionRead); err != nil {
return nil, err
}
software, err := svc.ds.SoftwareByID(ctx, id, includeCVEScores)
if teamID != nil {
exists, err := svc.ds.TeamExists(ctx, *teamID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "checking if team exists")
} else if !exists {
return nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil)
}
}
software, err := svc.ds.SoftwareByID(ctx, id, teamID, includeCVEScores)
if err != nil {
return nil, err
}

View File

@ -61,6 +61,10 @@ func TestServiceSoftwareInventoryAuth(t *testing.T) {
ds.CountSoftwareFunc = func(ctx context.Context, opt fleet.SoftwareListOptions) (int, error) {
return 0, nil
}
ds.SoftwareByIDFunc = func(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*fleet.Software, error) {
return &fleet.Software{}, nil
}
ds.TeamExistsFunc = func(ctx context.Context, teamID uint) (bool, error) { return true, nil }
svc, ctx := newTestService(t, ds, nil, nil)
for _, tc := range []struct {
@ -191,6 +195,13 @@ func TestServiceSoftwareInventoryAuth(t *testing.T) {
TeamID: ptr.Uint(1),
})
checkAuthErr(t, tc.shouldFailTeamRead, err)
// Get software by ID.
_, err = svc.SoftwareByID(ctx, 1, nil, false)
checkAuthErr(t, tc.shouldFailGlobalRead, err)
_, err = svc.SoftwareByID(ctx, 1, ptr.Uint(1), false)
checkAuthErr(t, tc.shouldFailTeamRead, err)
})
}
}

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/authz"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
@ -89,7 +90,8 @@ func (svc *Service) ListSoftwareTitles(
/////////////////////////////////////////////////////////////////////////////////
type getSoftwareTitleRequest struct {
ID uint `url:"id"`
ID uint `url:"id"`
TeamID *uint `query:"team_id,optional"`
}
type getSoftwareTitleResponse struct {
@ -102,7 +104,7 @@ func (r getSoftwareTitleResponse) error() error { return r.Err }
func getSoftwareTitleEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*getSoftwareTitleRequest)
software, err := svc.SoftwareTitleByID(ctx, req.ID)
software, err := svc.SoftwareTitleByID(ctx, req.ID, req.TeamID)
if err != nil {
return getSoftwareTitleResponse{Err: err}, nil
}
@ -110,14 +112,20 @@ func getSoftwareTitleEndpoint(ctx context.Context, request interface{}, svc flee
return getSoftwareTitleResponse{SoftwareTitle: software}, nil
}
func (svc *Service) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
// TODO: this is the autorization we do for GET /software, does it look right?
// checking with product here: https://github.com/fleetdm/fleet/issues/14674#issuecomment-1841395788
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
func (svc *Service) SoftwareTitleByID(ctx context.Context, id uint, teamID *uint) (*fleet.SoftwareTitle, error) {
if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{TeamID: teamID}, fleet.ActionRead); err != nil {
return nil, err
}
software, err := svc.ds.SoftwareTitleByID(ctx, id)
if teamID != nil {
exists, err := svc.ds.TeamExists(ctx, *teamID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "checking if team exists")
} else if !exists {
return nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil)
}
}
software, err := svc.ds.SoftwareTitleByID(ctx, id, teamID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "getting software title by id")
}

View File

@ -18,9 +18,10 @@ func TestServiceSoftwareTitlesAuth(t *testing.T) {
ds.ListSoftwareTitlesFunc = func(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) {
return []fleet.SoftwareTitle{}, 0, &fleet.PaginationMetadata{}, nil
}
ds.SoftwareTitleByIDFunc = func(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) {
ds.SoftwareTitleByIDFunc = func(ctx context.Context, id uint, teamID *uint) (*fleet.SoftwareTitle, error) {
return &fleet.SoftwareTitle{}, nil
}
ds.TeamExistsFunc = func(ctx context.Context, teamID uint) (bool, error) { return true, nil }
svc, ctx := newTestService(t, ds, nil, nil)
@ -154,8 +155,13 @@ func TestServiceSoftwareTitlesAuth(t *testing.T) {
}
// Get a software title
_, err = svc.SoftwareTitleByID(ctx, 1)
checkAuthErr(t, false, err)
_, err = svc.SoftwareTitleByID(ctx, 1, nil)
checkAuthErr(t, tc.shouldFailGlobalRead, err)
// Get a software title for a team
_, err = svc.SoftwareTitleByID(ctx, 1, ptr.Uint(1))
checkAuthErr(t, tc.shouldFailTeamRead, err)
})
}
}

View File

@ -84,7 +84,7 @@ func getStoredVulnerabilities(
ds fleet.Datastore,
softwareID uint,
) ([]fleet.SoftwareVulnerability, error) {
storedSoftware, err := ds.SoftwareByID(ctx, softwareID, false)
storedSoftware, err := ds.SoftwareByID(ctx, softwareID, nil, false)
if err != nil {
return nil, err
}

View File

@ -78,7 +78,7 @@ func TestAnalyzer(t *testing.T) {
t.Run("getStoredVulnerabilities", func(t *testing.T) {
t.Run("on error", func(t *testing.T) {
ds := new(mock.Store)
ds.SoftwareByIDFunc = func(ctx context.Context, id uint, includeCVEScores bool) (*fleet.Software, error) {
ds.SoftwareByIDFunc = func(ctx context.Context, id uint, teamID *uint, includeCVEScores bool) (*fleet.Software, error) {
return nil, errors.New("some error")
}