fleet/server/service/vulnerabilities_test.go

184 lines
4.8 KiB
Go
Raw Normal View History

package service
import (
"context"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/require"
)
// TestListVulnerabilities exercises Service.ListVulnerabilities as a global
// admin, backed by a datastore stub that ignores its options and always
// returns a single vulnerability.
func TestListVulnerabilities(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)
	adminCtx := viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}})

	// The stub builds its single result on every call (so CreatedAt is the
	// time of the call) and returns no pagination metadata.
	ds.ListVulnerabilitiesFunc = func(_ context.Context, _ fleet.VulnListOptions) ([]fleet.VulnerabilityWithMetadata, *fleet.PaginationMetadata, error) {
		vuln := fleet.VulnerabilityWithMetadata{
			CVE: fleet.CVE{
				CVE:         "CVE-2019-1234",
				Description: ptr.StringPtr("A vulnerability"),
			},
			CreatedAt:  time.Now(),
			HostsCount: 10,
		}
		return []fleet.VulnerabilityWithMetadata{vuln}, nil, nil
	}

	t.Run("no list options", func(t *testing.T) {
		_, _, err := svc.ListVulnerabilities(adminCtx, fleet.VulnListOptions{})
		require.NoError(t, err)
	})

	t.Run("can only sort by supported columns", func(t *testing.T) {
		// sortedBy builds list options ordered by key, restricted to the
		// free-tier sortable columns.
		sortedBy := func(key string) fleet.VulnListOptions {
			return fleet.VulnListOptions{
				ListOptions:      fleet.ListOptions{OrderKey: key},
				ValidSortColumns: freeValidVulnSortColumns,
			}
		}

		// An order key outside the allowed columns must be rejected.
		_, _, err := svc.ListVulnerabilities(adminCtx, sortedBy("invalid"))
		require.ErrorContains(t, err, "invalid order key")

		// "cve" is accepted as a sort column.
		_, _, err = svc.ListVulnerabilities(adminCtx, sortedBy("cve"))
		require.NoError(t, err)
	})
}
2024-02-14 21:42:16 +00:00
// TestVulnerabilitesAuth verifies which users may call ListVulnerabilities,
// CountVulnerabilities and Vulnerability, both globally (no team) and scoped
// to team 1. Global admins, maintainers and observers may do both. Team users
// of team 1 may read team 1 but are denied global reads. A team admin of a
// different team (team 2) is denied both.
func TestVulnerabilitesAuth(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	// Every datastore stub succeeds with an empty result and every team
	// exists, so the outcomes asserted below are expected to depend only on
	// the viewer's roles.
	ds.ListVulnerabilitiesFunc = func(_ context.Context, _ fleet.VulnListOptions) ([]fleet.VulnerabilityWithMetadata, *fleet.PaginationMetadata, error) {
		return []fleet.VulnerabilityWithMetadata{}, &fleet.PaginationMetadata{}, nil
	}
	ds.VulnerabilityFunc = func(_ context.Context, _ string, _ *uint, _ bool) (*fleet.VulnerabilityWithMetadata, error) {
		return &fleet.VulnerabilityWithMetadata{}, nil
	}
	ds.CountVulnerabilitiesFunc = func(_ context.Context, _ fleet.VulnListOptions) (uint, error) {
		return 0, nil
	}
	ds.TeamExistsFunc = func(_ context.Context, _ uint) (bool, error) {
		return true, nil
	}

	testCases := []struct {
		name                 string
		user                 *fleet.User
		shouldFailGlobalRead bool // expected outcome for calls without a team
		shouldFailTeamRead   bool // expected outcome for calls scoped to team 1
	}{
		{
			name: "global-admin",
			user: &fleet.User{
				ID:         1,
				GlobalRole: ptr.String(fleet.RoleAdmin),
			},
			shouldFailGlobalRead: false,
			shouldFailTeamRead:   false,
		},
		{
			name: "global-maintainer",
			user: &fleet.User{
				ID:         1,
				GlobalRole: ptr.String(fleet.RoleMaintainer),
			},
			shouldFailGlobalRead: false,
			shouldFailTeamRead:   false,
		},
		{
			name: "global-observer",
			user: &fleet.User{
				ID:         1,
				GlobalRole: ptr.String(fleet.RoleObserver),
			},
			shouldFailGlobalRead: false,
			shouldFailTeamRead:   false,
		},
		{
			name: "team-admin-belongs-to-team",
			user: &fleet.User{
				ID: 1,
				Teams: []fleet.UserTeam{{
					Team: fleet.Team{ID: 1},
					Role: fleet.RoleAdmin,
				}},
			},
			shouldFailGlobalRead: true,
			shouldFailTeamRead:   false,
		},
		{
			name: "team-maintainer-belongs-to-team",
			user: &fleet.User{
				ID: 1,
				Teams: []fleet.UserTeam{{
					Team: fleet.Team{ID: 1},
					Role: fleet.RoleMaintainer,
				}},
			},
			shouldFailGlobalRead: true,
			shouldFailTeamRead:   false,
		},
		{
			name: "team-observer-belongs-to-team",
			user: &fleet.User{
				ID: 1,
				Teams: []fleet.UserTeam{{
					Team: fleet.Team{ID: 1},
					Role: fleet.RoleObserver,
				}},
			},
			shouldFailGlobalRead: true,
			shouldFailTeamRead:   false,
		},
		{
			name: "team-admin-does-not-belong-to-team",
			user: &fleet.User{
				ID: 1,
				Teams: []fleet.UserTeam{{
					Team: fleet.Team{ID: 2},
					Role: fleet.RoleAdmin,
				}},
			},
			shouldFailGlobalRead: true,
			shouldFailTeamRead:   true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Build this case's viewer context locally. Assigning back to the
			// captured outer ctx would mutate state shared by all subtests and
			// derive each case's context from the previous case's one (and
			// would race if the subtests were ever made parallel).
			userCtx := viewer.NewContext(ctx, viewer.Viewer{User: tc.user})

			_, _, err := svc.ListVulnerabilities(userCtx, fleet.VulnListOptions{})
			checkAuthErr(t, tc.shouldFailGlobalRead, err)
			_, _, err = svc.ListVulnerabilities(userCtx, fleet.VulnListOptions{
				TeamID: 1,
			})
			checkAuthErr(t, tc.shouldFailTeamRead, err)

			_, err = svc.CountVulnerabilities(userCtx, fleet.VulnListOptions{})
			checkAuthErr(t, tc.shouldFailGlobalRead, err)
			_, err = svc.CountVulnerabilities(userCtx, fleet.VulnListOptions{
				TeamID: 1,
			})
			checkAuthErr(t, tc.shouldFailTeamRead, err)

			_, err = svc.Vulnerability(userCtx, "CVE-2019-1234", nil, false)
			checkAuthErr(t, tc.shouldFailGlobalRead, err)
			_, err = svc.Vulnerability(userCtx, "CVE-2019-1234", ptr.Uint(1), false)
			checkAuthErr(t, tc.shouldFailTeamRead, err)
		})
	}
}