diff --git a/changes/17347-team-user-os-version-restrict b/changes/17347-team-user-os-version-restrict new file mode 100644 index 000000000..49d0bb6a6 --- /dev/null +++ b/changes/17347-team-user-os-version-restrict @@ -0,0 +1 @@ +For GET fleet/os_versions and GET fleet/os_versions/[id], team users no longer have access to os versions on hosts from other teams. diff --git a/cmd/fleet/serve_test.go b/cmd/fleet/serve_test.go index 0da6a384e..a2047430a 100644 --- a/cmd/fleet/serve_test.go +++ b/cmd/fleet/serve_test.go @@ -304,7 +304,9 @@ func TestCronVulnerabilitiesCreatesDatabasesPath(t *testing.T) { // we should not get this far before we see the directory being created return nil, errors.New("shouldn't happen") } - ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) { + ds.OSVersionsFunc = func( + ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string, + ) (*fleet.OSVersions, error) { return &fleet.OSVersions{}, nil } ds.SyncHostsSoftwareFunc = func(ctx context.Context, updatedAt time.Time) error { @@ -452,7 +454,9 @@ func TestScanVulnerabilities(t *testing.T) { ds.DeleteOutOfDateVulnerabilitiesFunc = func(ctx context.Context, source fleet.VulnerabilitySource, duration time.Duration) error { return nil } - ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) { + ds.OSVersionsFunc = func( + ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string, + ) (*fleet.OSVersions, error) { return &fleet.OSVersions{ CountsUpdatedAt: time.Now(), OSVersions: []fleet.OSVersion{ diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index f781dddc4..df2f4395b 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -4464,8 +4464,10 @@ func (ds *Datastore) UpdateHost(ctx context.Context, host *fleet.Host) error { ) } -func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error) { - jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamID) +func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamFilter *fleet.TeamFilter) ( + *fleet.OSVersion, *time.Time, error, +) { + jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamFilter) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil, notFound("OSVersion") @@ -4510,7 +4512,9 @@ func (ds *Datastore) OSVersion(ctx context.Context, osVersionID uint, teamID *ui // counts for the same macOS version on x86_64 and arm64 architectures are counted together. // Results can be filtered using the following optional criteria: team id, platform, or name and // version. Name cannot be used without version, and conversely, version cannot be used without name. -func (ds *Datastore) OSVersions(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) { +func (ds *Datastore) OSVersions( + ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string, +) (*fleet.OSVersions, error) { if name != nil && version == nil { return nil, errors.New("invalid usage: cannot filter by name without version") } @@ -4518,7 +4522,7 @@ func (ds *Datastore) OSVersions(ctx context.Context, teamID *uint, platform *str return nil, errors.New("invalid usage: cannot filter by version without name") } - jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamID) + jsonValue, updatedAt, err := ds.executeOSVersionQuery(ctx, teamFilter) if err != nil { return nil, err } @@ -4568,30 +4572,34 @@ func (ds *Datastore) OSVersions(ctx context.Context, teamID *uint, platform *str return res, nil } -func (ds *Datastore) executeOSVersionQuery(ctx context.Context, teamID *uint) (*json.RawMessage, time.Time, error) { +func (ds *Datastore) executeOSVersionQuery(ctx context.Context, teamFilter *fleet.TeamFilter) ( + *json.RawMessage, time.Time, error, +) { query := ` SELECT json_value, updated_at FROM aggregated_stats - WHERE - id = ? AND - global_stats = ? AND - type = ? + WHERE type = ? ` + args := []interface{}{aggregatedStatsTypeOSVersions} + switch { + case teamFilter != nil && teamFilter.TeamID != nil: + query += " AND id = ? AND global_stats = ?" + args = append(args, *teamFilter.TeamID, false) + case teamFilter != nil: + query += " AND " + ds.whereFilterGlobalOrTeamIDByTeamsWithSqlFilter( + *teamFilter, "global_stats = 1 AND id = 0", "global_stats = 0 AND id", + ) + default: + query += " AND id = ? AND global_stats = ?" + args = append(args, 0, true) + } var row struct { JSONValue *json.RawMessage `db:"json_value"` UpdatedAt time.Time `db:"updated_at"` } - - id := uint(0) - globalStats := true - if teamID != nil { - id = *teamID - globalStats = false - } - - err := sqlx.GetContext(ctx, ds.reader(ctx), &row, query, id, globalStats, aggregatedStatsTypeOSVersions) + err := sqlx.GetContext(ctx, ds.reader(ctx), &row, query, args...) if err != nil { if err == sql.ErrNoRows { return nil, time.Time{}, ctxerr.Wrap(ctx, notFound("OSVersion")) diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index fa699ecdc..d1dedf061 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -6275,7 +6275,8 @@ func testOSVersions(t *testing.T, ds *Datastore) { require.Equal(t, &expected[0], osVersion) // team 1 - osVersions, err = ds.OSVersions(ctx, &team1.ID, nil, nil, nil) + userAdmin := &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)} + osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: &team1.ID, User: userAdmin}, nil, nil, nil) require.NoError(t, err) expected = []fleet.OSVersion{ @@ -6284,16 +6285,25 @@ func testOSVersions(t *testing.T, ds *Datastore) { } require.Equal(t, expected, osVersions.OSVersions) - osVersion, _, err = ds.OSVersion(ctx, 5, &team1.ID) + osVersion, _, err = ds.OSVersion(ctx, 5, &fleet.TeamFilter{TeamID: &team1.ID}) require.NoError(t, err) require.Equal(t, &expected[0], osVersion) - osVersion, _, err = ds.OSVersion(ctx, 2, &team1.ID) + osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{TeamID: &team1.ID, User: userAdmin}) + require.NoError(t, err) + require.Equal(t, &expected[1], osVersion) + + userTeam1 := &fleet.User{Teams: []fleet.UserTeam{{Team: *team1, Role: fleet.RoleAdmin}}} + osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{User: userTeam1}, nil, nil, nil) + require.NoError(t, err) + require.Equal(t, expected, osVersions.OSVersions) + + osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{User: userTeam1}) require.NoError(t, err) require.Equal(t, &expected[1], osVersion) // team 2 - osVersions, err = ds.OSVersions(ctx, &team2.ID, nil, nil, nil) + osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: &team2.ID}, nil, nil, nil) require.NoError(t, err) expected = []fleet.OSVersion{ @@ -6302,26 +6312,30 @@ func testOSVersions(t *testing.T, ds *Datastore) { } require.Equal(t, expected, osVersions.OSVersions) - osVersion, _, err = ds.OSVersion(ctx, 2, &team2.ID) + osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{TeamID: &team2.ID}) require.NoError(t, err) require.Equal(t, &expected[0], osVersion) - osVersion, _, err = ds.OSVersion(ctx, 3, &team2.ID) + osVersion, _, err = ds.OSVersion(ctx, 3, &fleet.TeamFilter{TeamID: &team2.ID}) require.NoError(t, err) require.Equal(t, &expected[1], osVersion) + // Wrong team + _, _, err = ds.OSVersion(ctx, 3, &fleet.TeamFilter{User: userTeam1}) + require.True(t, fleet.IsNotFound(err)) + // team 3 (no hosts assigned to team) - osVersions, err = ds.OSVersions(ctx, &team3.ID, nil, nil, nil) + osVersions, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: &team3.ID}, nil, nil, nil) require.NoError(t, err) expected = []fleet.OSVersion{} require.Equal(t, expected, osVersions.OSVersions) - osVersion, _, err = ds.OSVersion(ctx, 2, &team3.ID) + osVersion, _, err = ds.OSVersion(ctx, 2, &fleet.TeamFilter{TeamID: &team3.ID}) require.Error(t, err) require.Nil(t, osVersion) // non-existent team - _, err = ds.OSVersions(ctx, ptr.Uint(404), nil, nil, nil) + _, err = ds.OSVersions(ctx, &fleet.TeamFilter{TeamID: ptr.Uint(404)}, nil, nil, nil) require.Error(t, err) // new host with arm64 diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index 1c6dedef5..8c4d7c62c 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -889,6 +889,14 @@ func (ds *Datastore) whereFilterHostsByTeams(filter fleet.TeamFilter, hostKey st // filterTableAlias is the name/alias of the table to use in generating the // SQL. func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, filterTableAlias string) string { + globalFilter := fmt.Sprintf("%s.team_id = 0", filterTableAlias) + teamIDFilter := fmt.Sprintf("%s.team_id", filterTableAlias) + return ds.whereFilterGlobalOrTeamIDByTeamsWithSqlFilter(filter, globalFilter, teamIDFilter) +} + +func (ds *Datastore) whereFilterGlobalOrTeamIDByTeamsWithSqlFilter( + filter fleet.TeamFilter, globalSqlFilter string, teamIDSqlFilter string, +) string { if filter.User == nil { // This is likely unintentional, however we would like to return no // results rather than panicking or returning some other error. At least @@ -897,9 +905,9 @@ func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, f return "FALSE" } - defaultAllowClause := fmt.Sprintf("%s.team_id = 0", filterTableAlias) + defaultAllowClause := globalSqlFilter if filter.TeamID != nil { - defaultAllowClause = fmt.Sprintf("%s.team_id = %d", filterTableAlias, *filter.TeamID) + defaultAllowClause = fmt.Sprintf("%s = %d", teamIDSqlFilter, *filter.TeamID) } if filter.User.GlobalRole != nil { @@ -944,7 +952,7 @@ func (ds *Datastore) whereFilterGlobalOrTeamIDByTeams(filter fleet.TeamFilter, f return "FALSE" } - return fmt.Sprintf("%s.team_id IN (%s)", filterTableAlias, strings.Join(idStrs, ",")) + return fmt.Sprintf("%s IN (%s)", teamIDSqlFilter, strings.Join(idStrs, ",")) } // whereFilterTeams returns the appropriate condition to use in the WHERE diff --git a/server/datastore/mysql/vulnerabilities.go b/server/datastore/mysql/vulnerabilities.go index 97542dd7e..6b46da65c 100644 --- a/server/datastore/mysql/vulnerabilities.go +++ b/server/datastore/mysql/vulnerabilities.go @@ -108,7 +108,11 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint } func (ds *Datastore) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) (vos []*fleet.VulnerableOS, updatedAt time.Time, err error) { - osvs, err := ds.OSVersions(ctx, teamID, nil, nil, nil) + var teamFilter *fleet.TeamFilter + if teamID != nil { + teamFilter = &fleet.TeamFilter{TeamID: teamID} + } + osvs, err := ds.OSVersions(ctx, teamFilter, nil, nil, nil) if err != nil && !fleet.IsNotFound(err) { return nil, updatedAt, ctxerr.Wrap(ctx, err, "fetching team OS versions") } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index d351e4ca2..245905812 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -315,10 +315,12 @@ type Datastore interface { GetMunkiIssue(ctx context.Context, munkiIssueID uint) (*MunkiIssue, error) GetMDMSolution(ctx context.Context, mdmID uint) (*MDMSolution, error) - OSVersions(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*OSVersions, error) + OSVersions(ctx context.Context, teamFilter *TeamFilter, platform *string, name *string, version *string) (*OSVersions, error) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) ([]*VulnerableOS, time.Time, error) SoftwareByCVE(ctx context.Context, cve string, teamID *uint) ([]*VulnerableSoftware, time.Time, error) - OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*OSVersion, *time.Time, error) + // OSVersion returns the OSVersion with the provided ID. If teamFilter is not nil, then the OSVersion is filtered. + // The returned OSVersion is accompanied by the time it was last updated. + OSVersion(ctx context.Context, osVersionID uint, teamFilter *TeamFilter) (*OSVersion, *time.Time, error) UpdateOSVersions(ctx context.Context) error /////////////////////////////////////////////////////////////////////////////// diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 1d874c341..0b1606e67 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -244,13 +244,13 @@ type GetMunkiIssueFunc func(ctx context.Context, munkiIssueID uint) (*fleet.Munk type GetMDMSolutionFunc func(ctx context.Context, mdmID uint) (*fleet.MDMSolution, error) -type OSVersionsFunc func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) +type OSVersionsFunc func(ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string) (*fleet.OSVersions, error) type OSVersionsByCVEFunc func(ctx context.Context, cve string, teamID *uint) ([]*fleet.VulnerableOS, time.Time, error) type SoftwareByCVEFunc func(ctx context.Context, cve string, teamID *uint) ([]*fleet.VulnerableSoftware, time.Time, error) -type OSVersionFunc func(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error) +type OSVersionFunc func(ctx context.Context, osVersionID uint, teamFilter *fleet.TeamFilter) (*fleet.OSVersion, *time.Time, error) type UpdateOSVersionsFunc func(ctx context.Context) error @@ -2908,11 +2908,11 @@ func (s *DataStore) GetMDMSolution(ctx context.Context, mdmID uint) (*fleet.MDMS return s.GetMDMSolutionFunc(ctx, mdmID) } -func (s *DataStore) OSVersions(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) { +func (s *DataStore) OSVersions(ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string) (*fleet.OSVersions, error) { s.mu.Lock() s.OSVersionsFuncInvoked = true s.mu.Unlock() - return s.OSVersionsFunc(ctx, teamID, platform, name, version) + return s.OSVersionsFunc(ctx, teamFilter, platform, name, version) } func (s *DataStore) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) ([]*fleet.VulnerableOS, time.Time, error) { @@ -2929,11 +2929,11 @@ func (s *DataStore) SoftwareByCVE(ctx context.Context, cve string, teamID *uint) return s.SoftwareByCVEFunc(ctx, cve, teamID) } -func (s *DataStore) OSVersion(ctx context.Context, osVersionID uint, teamID *uint) (*fleet.OSVersion, *time.Time, error) { +func (s *DataStore) OSVersion(ctx context.Context, osVersionID uint, teamFilter *fleet.TeamFilter) (*fleet.OSVersion, *time.Time, error) { s.mu.Lock() s.OSVersionFuncInvoked = true s.mu.Unlock() - return s.OSVersionFunc(ctx, osVersionID, teamID) + return s.OSVersionFunc(ctx, osVersionID, teamFilter) } func (s *DataStore) UpdateOSVersions(ctx context.Context) error { diff --git a/server/service/hosts.go b/server/service/hosts.go index c01d1047f..2591df127 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -1816,17 +1816,33 @@ func (svc *Service) OSVersions(ctx context.Context, teamID *uint, platform *stri return nil, count, nil, &fleet.BadRequestError{Message: "Invalid order key"} } - osVersions, err := svc.ds.OSVersions(ctx, teamID, platform, name, version) - if err != nil && fleet.IsNotFound(err) { - // differentiate case where team was added after UpdateOSVersions last ran - if teamID != nil && *teamID > 0 { - // most of the time, team should exist so checking here saves unnecessary db calls - _, err := svc.ds.Team(ctx, *teamID) - if err != nil { - return nil, count, nil, err - } + if teamID != nil { + // This auth check ensures we return 403 if the user doesn't have access to the team + if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{TeamID: teamID}, fleet.ActionRead); err != nil { + return nil, count, nil, err } - // if team exists but stats have not yet been gathered, return empty JSON array + exists, err := svc.ds.TeamExists(ctx, *teamID) + if err != nil { + return nil, count, nil, ctxerr.Wrap(ctx, err, "checking if team exists") + } else if !exists { + return nil, count, nil, fleet.NewInvalidArgumentError("team_id", fmt.Sprintf("team %d does not exist", *teamID)). + WithStatus(http.StatusNotFound) + } + } + + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, count, nil, fleet.ErrNoContext + } + osVersions, err := svc.ds.OSVersions( + ctx, &fleet.TeamFilter{ + User: vc.User, + IncludeObserver: true, + TeamID: teamID, + }, platform, name, version, + ) + if err != nil && fleet.IsNotFound(err) { + // It is possible that os exists, but aggregation job has not run yet. osVersions = &fleet.OSVersions{} } else if err != nil { return nil, count, nil, err @@ -1913,15 +1929,36 @@ func (svc *Service) OSVersion(ctx context.Context, osID uint, teamID *uint, incl } if teamID != nil { + // This auth check ensures we return 403 if the user doesn't have access to the team + if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{TeamID: teamID}, fleet.ActionRead); err != nil { + return nil, nil, err + } exists, err := svc.ds.TeamExists(ctx, *teamID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "checking if team exists") } else if !exists { - return nil, nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil) + return nil, nil, fleet.NewInvalidArgumentError("team_id", fmt.Sprintf("team %d does not exist", *teamID)). + WithStatus(http.StatusNotFound) } } - osVersion, updateTime, err := svc.ds.OSVersion(ctx, osID, teamID) + + vc, ok := viewer.FromContext(ctx) + if !ok { + return nil, nil, fleet.ErrNoContext + } + osVersion, updateTime, err := svc.ds.OSVersion( + ctx, osID, &fleet.TeamFilter{ + User: vc.User, + IncludeObserver: true, + TeamID: teamID, + }, + ) if err != nil { + if fleet.IsNotFound(err) { + // We return an empty result here to be consistent with the fleet/os_versions behavior. + // It is possible the os version exists, but the aggregation job has not run yet. + return nil, nil, nil + } return nil, nil, err } diff --git a/server/service/hosts_test.go b/server/service/hosts_test.go index 621b9c7bf..4358399eb 100644 --- a/server/service/hosts_test.go +++ b/server/service/hosts_test.go @@ -994,26 +994,19 @@ func TestEmptyTeamOSVersions(t *testing.T) { testVersions := []fleet.OSVersion{{HostsCount: 1, Name: "macOS 12.1", Platform: "darwin"}} - ds.TeamFunc = func(ctx context.Context, teamID uint) (*fleet.Team, error) { - if teamID == 1 { - return &fleet.Team{ - Name: "team1", - }, nil + ds.TeamExistsFunc = func(ctx context.Context, teamID uint) (bool, error) { + if teamID == 3 { + return false, nil } - if teamID == 2 { - return &fleet.Team{ - Name: "team2", - }, nil - } - - return nil, newNotFoundError() + return true, nil } - - ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) { - if *teamID == 1 { + ds.OSVersionsFunc = func( + ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string, + ) (*fleet.OSVersions, error) { + if *teamFilter.TeamID == 1 { return &fleet.OSVersions{CountsUpdatedAt: time.Now(), OSVersions: testVersions}, nil } - if *teamID == 4 { + if *teamFilter.TeamID == 4 { return nil, errors.New("some unknown error") } @@ -1037,7 +1030,7 @@ func TestEmptyTeamOSVersions(t *testing.T) { // team does not exist _, _, _, err = svc.OSVersions(test.UserContext(ctx, test.UserAdmin), ptr.Uint(3), ptr.String("darwin"), nil, nil, fleet.ListOptions{}, false) require.Error(t, err) - require.Equal(t, "not found", fmt.Sprint(err)) + require.Contains(t, fmt.Sprint(err), "does not exist") // some unknown error _, _, _, err = svc.OSVersions(test.UserContext(ctx, test.UserAdmin), ptr.Uint(4), ptr.String("darwin"), nil, nil, fleet.ListOptions{}, false) @@ -1058,7 +1051,9 @@ func TestOSVersionsListOptions(t *testing.T) { {HostsCount: 6, NameOnly: "Ubuntu 21.04", Platform: "ubuntu"}, } - ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string, name *string, version *string) (*fleet.OSVersions, error) { + ds.OSVersionsFunc = func( + ctx context.Context, teamFilter *fleet.TeamFilter, platform *string, name *string, version *string, + ) (*fleet.OSVersions, error) { return &fleet.OSVersions{CountsUpdatedAt: time.Now(), OSVersions: testVersions}, nil } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 5cbea73c3..02090bb1c 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -7903,7 +7903,8 @@ func (s *integrationTestSuite) TestOSVersions() { require.Equal(t, &expectedVersion, osVersionResp.OSVersion) // invalid id - s.DoJSON("GET", "/api/latest/fleet/os_versions/999", nil, http.StatusNotFound, &osVersionResp) + s.DoJSON("GET", "/api/latest/fleet/os_versions/999", nil, http.StatusOK, &osVersionResp) + assert.Zero(t, osVersionResp.OSVersion.HostsCount) // name and version filters s.DoJSON("GET", "/api/latest/fleet/os_versions", nil, http.StatusOK, &osVersionsResp, "os_name", "Windows 11 Pro 21H2", "os_version", "10.0.22000.2") diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 0096e594c..5273f5eee 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -3406,14 +3406,15 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() { require.Equal(t, *vulnMeta[0].CISAKnownExploit, **osVersionsResp.OSVersions[0].Vulnerabilities[0].CISAKnownExploit) require.Equal(t, *vulnMeta[0].Published, **osVersionsResp.OSVersions[0].Vulnerabilities[0].CVEPublished) require.Equal(t, vulnMeta[0].Description, **osVersionsResp.OSVersions[0].Vulnerabilities[0].Description) + expectedOSVersion := osVersionsResp.OSVersions[0] var osVersionResp getOSVersionResponse s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusOK, &osVersionResp) - require.Equal(t, &osVersionsResp.OSVersions[0], osVersionResp.OSVersion) + require.Equal(t, &expectedOSVersion, osVersionResp.OSVersion) // OS versions with invalid team s.DoJSON( - "GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusForbidden, &osVersionResp, "team_id", + "GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusNotFound, &osVersionResp, "team_id", "99999", ) @@ -3428,19 +3429,73 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() { ) osVersionResp = getOSVersionResponse{} s.DoJSON( - "GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusNotFound, &osVersionResp, "team_id", + "GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusOK, &osVersionResp, "team_id", fmt.Sprintf("%d", tr.Team.ID), ) + assert.Zero(t, osVersionResp.OSVersion.HostsCount) // return empty json if UpdateOSVersions cron hasn't run yet for new team - team, err := s.ds.NewTeam(context.Background(), &fleet.Team{Name: "new team"}) + team0, err := s.ds.NewTeam(context.Background(), &fleet.Team{Name: "new team"}) require.NoError(t, err) - require.NoError(t, s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{hosts[0].ID})) - s.DoJSON("GET", "/api/latest/fleet/os_versions", nil, http.StatusOK, &osVersionsResp, "team_id", fmt.Sprintf("%d", team.ID)) + require.NoError(t, s.ds.AddHostsToTeam(context.Background(), &team0.ID, []uint{hosts[0].ID})) + s.DoJSON("GET", "/api/latest/fleet/os_versions", nil, http.StatusOK, &osVersionsResp, "team_id", fmt.Sprintf("%d", team0.ID)) require.Len(t, osVersionsResp.OSVersions, 0) // return err if team_id is invalid s.DoJSON("GET", "/api/latest/fleet/os_versions", nil, http.StatusBadRequest, &osVersionsResp, "team_id", "invalid") + + // Create another team and a team user + team1, err := s.ds.NewTeam( + context.Background(), &fleet.Team{ + ID: 42, + Name: "team1-os_version", + Description: "desc team1", + }, + ) + require.NoError(t, err) + // Create a new admin for team1. + password := test.GoodPassword + email := "admin-team1-os_version@example.com" + u := &fleet.User{ + Name: "admin team1", + Email: email, + GlobalRole: nil, + Teams: []fleet.UserTeam{ + { + Team: *team1, + Role: fleet.RoleAdmin, + }, + }, + } + require.NoError(t, u.SetPassword(password, 10, 10)) + _, err = s.ds.NewUser(context.Background(), u) + require.NoError(t, err) + + s.setTokenForTest(t, email, test.GoodPassword) + + // generate aggregated stats + require.NoError(t, s.ds.UpdateOSVersions(context.Background())) + // team1 user does not have access to team0 host + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions"), nil, http.StatusOK, &osVersionsResp) + assert.Empty(t, osVersionsResp.OSVersions) + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusOK, &osVersionResp) + assert.Zero(t, osVersionResp.OSVersion.HostsCount) + + // Move host from team0 to team1 + require.NoError(t, s.ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{hosts[0].ID})) + require.NoError(t, s.ds.UpdateOSVersions(context.Background())) + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions"), nil, http.StatusOK, &osVersionsResp) + require.Len(t, osVersionsResp.OSVersions, 1) + assert.Equal(t, expectedOSVersion, osVersionsResp.OSVersions[0]) + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusOK, &osVersionResp) + require.Equal(t, &expectedOSVersion, osVersionResp.OSVersion) + + // Team user is forbidden to access invalid team + s.DoJSON( + "GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusForbidden, &osVersionResp, "team_id", + "99999", + ) + } func (s *integrationEnterpriseTestSuite) TestMDMNotConfiguredEndpoints() {