mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
482 lines
14 KiB
Go
482 lines
14 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/WatchBeam/clock"
|
|
"github.com/fleetdm/fleet/v4/server/authz"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/mock"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/fleetdm/fleet/v4/server/test"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestHostDetails checks that getHostDetails copies the labels and packs the
// datastore reports for a host into the returned host detail.
func TestHostDetails(t *testing.T) {
	store := new(mock.Store)
	svc := &Service{ds: store}

	wantLabels := []*fleet.Label{{Name: "foobar", Description: "the foobar label"}}
	wantPacks := []*fleet.Pack{{Name: "pack1"}, {Name: "pack2"}}

	store.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return wantLabels, nil
	}
	store.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return wantPacks, nil
	}
	// Software and policies are not asserted on here; stub them with empty results.
	store.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host) error { return nil }
	store.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}

	detail, err := svc.getHostDetails(test.UserContext(test.UserAdmin), &fleet.Host{ID: 3})
	require.NoError(t, err)
	assert.Equal(t, wantLabels, detail.Labels)
	assert.Equal(t, wantPacks, detail.Packs)
}
|
|
|
|
// TestHostAuth checks which host operations each global or team role may
// perform. Host ID 1 (identifier "1") resolves to a host on team 1; any other
// ID or identifier resolves to a host with no team.
func TestHostAuth(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	teamHost := &fleet.Host{TeamID: ptr.Uint(1)}
	globalHost := &fleet.Host{}

	// hostForID maps a host ID onto one of the two fixtures above.
	hostForID := func(id uint) *fleet.Host {
		if id == 1 {
			return teamHost
		}
		return globalHost
	}

	ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		return hostForID(id), nil
	}
	ds.HostFunc = func(ctx context.Context, id uint, skipLoadingExtras bool) (*fleet.Host, error) {
		return hostForID(id), nil
	}
	ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
		if identifier == "1" {
			return teamHost, nil
		}
		return globalHost, nil
	}
	ds.UpdateHostRefetchRequestedFunc = func(ctx context.Context, id uint, value bool) error {
		hostForID(id).RefetchRequested = true
		return nil
	}

	// The remaining datastore calls are stubbed to succeed with empty results.
	ds.DeleteHostFunc = func(ctx context.Context, hid uint) error { return nil }
	ds.DeleteHostsFunc = func(ctx context.Context, ids []uint) error { return nil }
	ds.AddHostsToTeamFunc = func(ctx context.Context, teamID *uint, hostIDs []uint) error { return nil }
	ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host) error { return nil }
	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return nil, nil
	}
	ds.ListLabelsForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Label, error) {
		return nil, nil
	}
	ds.ListPacksForHostFunc = func(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
		return nil, nil
	}
	ds.ListPoliciesForHostFunc = func(ctx context.Context, host *fleet.Host) ([]*fleet.HostPolicy, error) {
		return nil, nil
	}

	// Fields left out of a case default to false, i.e. the operation is allowed.
	testCases := []struct {
		name                  string
		user                  *fleet.User
		shouldFailGlobalWrite bool
		shouldFailGlobalRead  bool
		shouldFailTeamWrite   bool
		shouldFailTeamRead    bool
	}{
		{
			name: "global admin",
			user: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
		},
		{
			name: "global maintainer",
			user: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
		},
		{
			name:                  "global observer",
			user:                  &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
			shouldFailGlobalWrite: true,
			shouldFailTeamWrite:   true,
		},
		{
			name:                  "team maintainer, belongs to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
		},
		{
			name:                  "team observer, belongs to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
		},
		{
			name:                  "team maintainer, DOES NOT belong to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleMaintainer}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    true,
		},
		{
			name:                  "team observer, DOES NOT belong to team",
			user:                  &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 2}, Role: fleet.RoleObserver}}},
			shouldFailGlobalWrite: true,
			shouldFailGlobalRead:  true,
			shouldFailTeamWrite:   true,
			shouldFailTeamRead:    true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tc.user})

			// Reads, by ID and by identifier, against the team host then the global host.
			_, err := svc.GetHost(ctx, 1)
			checkAuthErr(t, tc.shouldFailTeamRead, err)
			_, err = svc.HostByIdentifier(ctx, "1")
			checkAuthErr(t, tc.shouldFailTeamRead, err)
			_, err = svc.GetHost(ctx, 2)
			checkAuthErr(t, tc.shouldFailGlobalRead, err)
			_, err = svc.HostByIdentifier(ctx, "2")
			checkAuthErr(t, tc.shouldFailGlobalRead, err)

			// Writes.
			checkAuthErr(t, tc.shouldFailTeamWrite, svc.DeleteHost(ctx, 1))
			checkAuthErr(t, tc.shouldFailGlobalWrite, svc.DeleteHost(ctx, 2))
			checkAuthErr(t, tc.shouldFailTeamWrite, svc.DeleteHosts(ctx, []uint{1}, fleet.HostListOptions{}, nil))
			checkAuthErr(t, tc.shouldFailGlobalWrite, svc.DeleteHosts(ctx, []uint{2}, fleet.HostListOptions{}, nil))
			checkAuthErr(t, tc.shouldFailTeamWrite, svc.AddHostsToTeam(ctx, ptr.Uint(1), []uint{1}))
			checkAuthErr(t, tc.shouldFailTeamWrite, svc.AddHostsToTeamByFilter(ctx, ptr.Uint(1), fleet.HostListOptions{}, nil))

			// Refetching the team host is expected to follow the team read permission.
			checkAuthErr(t, tc.shouldFailTeamRead, svc.RefetchHost(ctx, 1))
		})
	}

	// ListHosts and GetHostSummary are allowed for every user and are covered
	// by TestListHosts and TestGetHostSummary.
}
|
|
|
|
// TestListHosts verifies that ListHosts returns the datastore's hosts to any
// user in the context, and is forbidden when the context carries no user.
func TestListHosts(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return []*fleet.Host{{ID: 1}}, nil
	}

	// An admin and a user without any role may both list hosts.
	for _, user := range []*fleet.User{test.UserAdmin, test.UserNoRoles} {
		hosts, err := svc.ListHosts(test.UserContext(user), fleet.HostListOptions{})
		require.NoError(t, err)
		require.Len(t, hosts, 1)
	}

	// Without a user the call must be rejected.
	_, err := svc.ListHosts(context.Background(), fleet.HostListOptions{})
	require.Error(t, err)
	require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
|
|
|
|
// TestGetHostSummary checks that GetHostSummary passes the datastore's status
// counts through unchanged, succeeds for a user without roles, and is
// forbidden when no user is present in the context.
func TestGetHostSummary(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	ds.GenerateHostStatusStatisticsFunc = func(ctx context.Context, filter fleet.TeamFilter, now time.Time, platform *string) (*fleet.HostSummary, error) {
		stats := fleet.HostSummary{
			OnlineCount:      1,
			OfflineCount:     2,
			MIACount:         3,
			NewCount:         4,
			TotalsHostsCount: 5,
		}
		return &stats, nil
	}

	adminCtx := test.UserContext(test.UserAdmin)
	summary, err := svc.GetHostSummary(adminCtx, nil, nil)
	require.NoError(t, err)
	require.Nil(t, summary.TeamID)
	require.Equal(t, uint(1), summary.OnlineCount)
	require.Equal(t, uint(2), summary.OfflineCount)
	require.Equal(t, uint(3), summary.MIACount)
	require.Equal(t, uint(4), summary.NewCount)
	require.Equal(t, uint(5), summary.TotalsHostsCount)

	noRolesCtx := test.UserContext(test.UserNoRoles)
	_, err = svc.GetHostSummary(noRolesCtx, nil, nil)
	require.NoError(t, err)

	// Without a user the call must be rejected.
	_, err = svc.GetHostSummary(context.Background(), nil, nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
|
|
|
|
// TestDeleteHost creates a host in a real MySQL datastore, deletes it through
// the service as an admin, and checks that listing hosts afterwards returns
// nothing.
func TestDeleteHost(t *testing.T) {
	ds := mysql.CreateMySQLDS(t)
	defer ds.Close()

	svc := newTestService(t, ds, nil, nil)

	mockClock := clock.NewMockClock()
	host := test.NewHost(t, ds, "foo", "192.168.1.10", "1", "1", mockClock.Now())
	// Stop here if the fixture was not created; deleting host 0 would only
	// produce misleading follow-on failures.
	require.NotZero(t, host.ID)

	require.NoError(t, svc.DeleteHost(test.UserContext(test.UserAdmin), host.ID))

	filter := fleet.TeamFilter{User: test.UserAdmin}
	hosts, err := ds.ListHosts(context.Background(), filter, fleet.HostListOptions{})
	require.NoError(t, err)
	assert.Empty(t, hosts)
}
|
|
|
|
// TestAddHostsToTeamByFilter checks that AddHostsToTeamByFilter, given no label,
// lists hosts matching the filter and forwards exactly their IDs, together
// with the requested (here nil) team ID, to the datastore.
func TestAddHostsToTeamByFilter(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	wantIDs := []uint{1, 2, 4}
	var wantTeam *uint

	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		hosts := make([]*fleet.Host, len(wantIDs))
		for i, id := range wantIDs {
			hosts[i] = &fleet.Host{ID: id}
		}
		return hosts, nil
	}
	ds.AddHostsToTeamFunc = func(ctx context.Context, teamID *uint, hostIDs []uint) error {
		assert.Equal(t, wantTeam, teamID)
		assert.Equal(t, wantIDs, hostIDs)
		return nil
	}

	ctx := test.UserContext(test.UserAdmin)
	require.NoError(t, svc.AddHostsToTeamByFilter(ctx, wantTeam, fleet.HostListOptions{}, nil))
	assert.True(t, ds.ListHostsFuncInvoked)
	assert.True(t, ds.AddHostsToTeamFuncInvoked)
}
|
|
|
|
func TestAddHostsToTeamByFilterLabel(t *testing.T) {
|
|
ds := new(mock.Store)
|
|
svc := newTestService(t, ds, nil, nil)
|
|
|
|
expectedHostIDs := []uint{6}
|
|
expectedTeam := ptr.Uint(1)
|
|
expectedLabel := ptr.Uint(2)
|
|
|
|
ds.ListHostsInLabelFunc = func(ctx context.Context, filter fleet.TeamFilter, lid uint, opt fleet.HostListOptions) ([]*fleet.Host, error) {
|
|
assert.Equal(t, *expectedLabel, lid)
|
|
var hosts []*fleet.Host
|
|
for _, id := range expectedHostIDs {
|
|
hosts = append(hosts, &fleet.Host{ID: id})
|
|
}
|
|
return hosts, nil
|
|
}
|
|
ds.AddHostsToTeamFunc = func(ctx context.Context, teamID *uint, hostIDs []uint) error {
|
|
assert.Equal(t, expectedHostIDs, hostIDs)
|
|
return nil
|
|
}
|
|
|
|
require.NoError(t, svc.AddHostsToTeamByFilter(test.UserContext(test.UserAdmin), expectedTeam, fleet.HostListOptions{}, expectedLabel))
|
|
assert.True(t, ds.ListHostsInLabelFuncInvoked)
|
|
assert.True(t, ds.AddHostsToTeamFuncInvoked)
|
|
}
|
|
|
|
// TestAddHostsToTeamByFilterEmptyHosts checks that when the filter matches no
// hosts, AddHostsToTeamByFilter succeeds without calling the datastore's
// AddHostsToTeam.
func TestAddHostsToTeamByFilterEmptyHosts(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
		return []*fleet.Host{}, nil
	}
	ds.AddHostsToTeamFunc = func(ctx context.Context, teamID *uint, hostIDs []uint) error { return nil }

	ctx := test.UserContext(test.UserAdmin)
	err := svc.AddHostsToTeamByFilter(ctx, nil, fleet.HostListOptions{}, nil)
	require.NoError(t, err)
	assert.True(t, ds.ListHostsFuncInvoked)
	assert.False(t, ds.AddHostsToTeamFuncInvoked)
}
|
|
|
|
// TestRefetchHost checks that admin, observer and maintainer test users can
// all request a refetch, and that the request marks the host's refetch flag
// as true in the datastore.
func TestRefetchHost(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	host := &fleet.Host{ID: 3}

	ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
		return host, nil
	}
	ds.UpdateHostRefetchRequestedFunc = func(ctx context.Context, id uint, value bool) error {
		assert.Equal(t, host.ID, id)
		assert.True(t, value)
		return nil
	}

	for _, user := range []*fleet.User{test.UserAdmin, test.UserObserver, test.UserMaintainer} {
		require.NoError(t, svc.RefetchHost(test.UserContext(user), host.ID))
	}
	assert.True(t, ds.HostLiteFuncInvoked)
	assert.True(t, ds.UpdateHostRefetchRequestedFuncInvoked)
}
|
|
|
|
func TestRefetchHostUserInTeams(t *testing.T) {
|
|
ds := new(mock.Store)
|
|
svc := newTestService(t, ds, nil, nil)
|
|
|
|
host := &fleet.Host{ID: 3, TeamID: ptr.Uint(4)}
|
|
|
|
ds.HostLiteFunc = func(ctx context.Context, id uint) (*fleet.Host, error) {
|
|
return host, nil
|
|
}
|
|
ds.UpdateHostRefetchRequestedFunc = func(ctx context.Context, id uint, value bool) error {
|
|
assert.Equal(t, host.ID, id)
|
|
assert.True(t, value)
|
|
return nil
|
|
}
|
|
|
|
maintainer := &fleet.User{
|
|
Teams: []fleet.UserTeam{
|
|
{
|
|
Team: fleet.Team{ID: 4},
|
|
Role: fleet.RoleMaintainer,
|
|
},
|
|
},
|
|
}
|
|
require.NoError(t, svc.RefetchHost(test.UserContext(maintainer), host.ID))
|
|
assert.True(t, ds.HostLiteFuncInvoked)
|
|
assert.True(t, ds.UpdateHostRefetchRequestedFuncInvoked)
|
|
ds.HostLiteFuncInvoked, ds.UpdateHostRefetchRequestedFuncInvoked = false, false
|
|
|
|
observer := &fleet.User{
|
|
Teams: []fleet.UserTeam{
|
|
{
|
|
Team: fleet.Team{ID: 4},
|
|
Role: fleet.RoleObserver,
|
|
},
|
|
},
|
|
}
|
|
require.NoError(t, svc.RefetchHost(test.UserContext(observer), host.ID))
|
|
assert.True(t, ds.HostLiteFuncInvoked)
|
|
assert.True(t, ds.UpdateHostRefetchRequestedFuncInvoked)
|
|
}
|
|
|
|
// TestEmptyTeamOSVersions covers OSVersions for a team with stats, a team that
// exists but has no stats, a team that does not exist, and an unexpected
// datastore error.
func TestEmptyTeamOSVersions(t *testing.T) {
	ds := new(mock.Store)
	svc := newTestService(t, ds, nil, nil)

	testVersions := []fleet.OSVersion{{HostsCount: 1, Name: "macOS 12.1", Platform: "darwin"}}

	// Teams 1 and 2 exist; any other team ID is reported as not found.
	ds.TeamFunc = func(ctx context.Context, teamID uint) (*fleet.Team, error) {
		switch teamID {
		case 1:
			return &fleet.Team{Name: "team1"}, nil
		case 2:
			return &fleet.Team{Name: "team2"}, nil
		default:
			return nil, notFoundError{}
		}
	}

	// Only team 1 has stats; team 4 simulates an unexpected datastore failure.
	ds.OSVersionsFunc = func(ctx context.Context, teamID *uint, platform *string) (*fleet.OSVersions, error) {
		// Every call below passes a team ID; fail the test rather than panic on
		// a nil dereference if that ever changes.
		require.NotNil(t, teamID)
		switch *teamID {
		case 1:
			return &fleet.OSVersions{CountsUpdatedAt: time.Now(), OSVersions: testVersions}, nil
		case 4:
			return nil, fmt.Errorf("some unknown error")
		default:
			return nil, notFoundError{}
		}
	}

	ctx := test.UserContext(test.UserAdmin)

	// team exists with stats
	vers, err := svc.OSVersions(ctx, ptr.Uint(1), ptr.String("darwin"))
	require.NoError(t, err)
	require.NotNil(t, vers)
	assert.Len(t, vers.OSVersions, 1)

	// team exists but no stats
	vers, err = svc.OSVersions(ctx, ptr.Uint(2), ptr.String("darwin"))
	require.NoError(t, err)
	require.NotNil(t, vers)
	assert.Empty(t, vers.OSVersions)

	// team does not exist; the result is discarded since only the error matters
	_, err = svc.OSVersions(ctx, ptr.Uint(3), ptr.String("darwin"))
	require.EqualError(t, err, "not found")

	// some unknown error
	_, err = svc.OSVersions(ctx, ptr.Uint(4), ptr.String("darwin"))
	require.EqualError(t, err, "some unknown error")
}
|