fleet/server/service/queries_test.go
Lucas Manuel Rodriguez 39f1029390
Make activities typed and auto-generate docs for each type (#9069)
* Make activities typed and auto-generate docs for each type

* Add pageOrderInSection to not break site

* Add do not edit note to generated file

* Add make generate-doc step

* Fix main merge
2022-12-23 13:05:16 -03:00

225 lines
6.4 KiB
Go

package service
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewQueryAttach(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil)
name := "bad"
query := "attach '/nope' as bad"
_, err := svc.NewQuery(
ctx,
fleet.QueryPayload{Name: &name, Query: &query},
)
require.Error(t, err)
}
func TestFilterQueriesForObserver(t *testing.T) {
require.True(t, onlyShowObserverCanRunQueries(&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)}))
require.False(t, onlyShowObserverCanRunQueries(&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)}))
require.False(t, onlyShowObserverCanRunQueries(&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}))
require.True(t, onlyShowObserverCanRunQueries(&fleet.User{Teams: []fleet.UserTeam{{Role: fleet.RoleObserver}}}))
require.True(t, onlyShowObserverCanRunQueries(&fleet.User{Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver},
{Role: fleet.RoleObserver},
}}))
require.False(t, onlyShowObserverCanRunQueries(&fleet.User{Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver},
{Role: fleet.RoleMaintainer},
}}))
}
// TestListQueries checks that ListQueries forwards the expected
// OnlyObserverCanRun option to the datastore for each kind of user.
func TestListQueries(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	cases := [...]struct {
		title        string
		user         *fleet.User
		expectedOpts fleet.ListQueryOptions
	}{
		{
			title:        "global admin",
			user:         &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
			expectedOpts: fleet.ListQueryOptions{OnlyObserverCanRun: false},
		},
		{
			title:        "global observer",
			user:         &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
			expectedOpts: fleet.ListQueryOptions{OnlyObserverCanRun: true},
		},
		{
			title:        "team maintainer",
			user:         &fleet.User{Teams: []fleet.UserTeam{{Role: fleet.RoleMaintainer}}},
			expectedOpts: fleet.ListQueryOptions{OnlyObserverCanRun: false},
		},
	}

	// Captured by the datastore mock so each subtest can inspect what the
	// service passed down.
	var (
		called         bool
		calledWithOpts fleet.ListQueryOptions
	)
	ds.ListQueriesFunc = func(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error) {
		called = true
		calledWithOpts = opt
		return []*fleet.Query{}, nil
	}

	for _, tt := range cases {
		t.Run(tt.title, func(t *testing.T) {
			// Reset the captured state so a previous case's call cannot make this
			// one pass, and a zero-valued expectation cannot pass when the
			// datastore was never reached.
			called = false
			calledWithOpts = fleet.ListQueryOptions{}

			viewerCtx := viewer.NewContext(ctx, viewer.Viewer{User: tt.user})
			_, err := svc.ListQueries(viewerCtx, fleet.ListOptions{})
			require.NoError(t, err)
			require.True(t, called, "expected ListQueries to call the datastore")
			assert.Equal(t, tt.expectedOpts, calledWithOpts)
		})
	}
}
// TestQueryAuth exercises the authorization behavior of the query service
// methods for several global and team roles.
//
// All datastore calls are mocked to succeed. The query with authoredQueryID
// (or name authoredQueryName) is reported as authored by teamMaintainer, and
// every other query by an unrelated user (ID 6666). Each case declares whether
// creating, writing (modify/delete/apply) and reading (get/list/specs) should
// fail. checkAuthErr, defined elsewhere in this package, compares each call's
// error against that expectation.
func TestQueryAuth(t *testing.T) {
	ds := new(mock.Store)
	svc, ctx := newTestService(t, ds, nil, nil)

	// teamID is the team every team-scoped user in this test belongs to.
	teamID := uint(1)

	authoredQueryID := uint(1)
	authoredQueryName := "authored"
	queryName := map[uint]string{
		authoredQueryID: authoredQueryName,
		2:               "not authored",
	}
	teamMaintainer := &fleet.User{ID: 42, Teams: []fleet.UserTeam{{Team: fleet.Team{ID: teamID}, Role: fleet.RoleMaintainer}}}

	ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
		return query, nil
	}
	ds.QueryByNameFunc = func(ctx context.Context, name string, opts ...fleet.OptionalArg) (*fleet.Query, error) {
		if name == authoredQueryName {
			return &fleet.Query{ID: 99, AuthorID: ptr.Uint(teamMaintainer.ID)}, nil
		}
		return &fleet.Query{ID: 8888, AuthorID: ptr.Uint(6666)}, nil
	}
	ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
		return nil
	}
	ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) {
		if id == authoredQueryID {
			return &fleet.Query{ID: 99, AuthorID: ptr.Uint(teamMaintainer.ID)}, nil
		}
		return &fleet.Query{ID: 8888, AuthorID: ptr.Uint(6666)}, nil
	}
	ds.SaveQueryFunc = func(ctx context.Context, query *fleet.Query) error {
		return nil
	}
	ds.DeleteQueryFunc = func(ctx context.Context, name string) error {
		return nil
	}
	ds.DeleteQueriesFunc = func(ctx context.Context, ids []uint) (uint, error) {
		return 0, nil
	}
	ds.ListQueriesFunc = func(ctx context.Context, opts fleet.ListQueryOptions) ([]*fleet.Query, error) {
		return nil, nil
	}
	ds.ApplyQueriesFunc = func(ctx context.Context, authID uint, queries []*fleet.Query) error {
		return nil
	}

	testCases := []struct {
		name            string
		user            *fleet.User
		qid             uint
		shouldFailWrite bool
		shouldFailRead  bool
		shouldFailNew   bool
	}{
		{
			name:            "global admin",
			user:            &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
			qid:             authoredQueryID,
			shouldFailWrite: false,
			shouldFailRead:  false,
			shouldFailNew:   false,
		},
		{
			name:            "global maintainer",
			user:            &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
			qid:             authoredQueryID,
			shouldFailWrite: false,
			shouldFailRead:  false,
			shouldFailNew:   false,
		},
		{
			name:            "global observer",
			user:            &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
			qid:             authoredQueryID,
			shouldFailWrite: true,
			shouldFailRead:  false,
			shouldFailNew:   true,
		},
		{
			name:            "team maintainer, author of the query",
			user:            teamMaintainer,
			qid:             authoredQueryID,
			shouldFailWrite: false,
			shouldFailRead:  false,
			shouldFailNew:   false,
		},
		{
			name:            "team maintainer, NOT author of the query",
			user:            teamMaintainer,
			qid:             2,
			shouldFailWrite: true,
			shouldFailRead:  false,
			shouldFailNew:   false,
		},
		{
			name:            "team observer",
			user:            &fleet.User{ID: 48, Teams: []fleet.UserTeam{{Team: fleet.Team{ID: teamID}, Role: fleet.RoleObserver}}},
			qid:             2,
			shouldFailWrite: true,
			shouldFailRead:  false,
			shouldFailNew:   true,
		},
	}

	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			// Deliberately shadows the outer ctx with one carrying this case's user.
			ctx := viewer.NewContext(ctx, viewer.Viewer{User: tt.user})

			_, err := svc.NewQuery(ctx, fleet.QueryPayload{Name: ptr.String("name"), Query: ptr.String("select 1")})
			checkAuthErr(t, tt.shouldFailNew, err)

			// Write operations.
			_, err = svc.ModifyQuery(ctx, tt.qid, fleet.QueryPayload{})
			checkAuthErr(t, tt.shouldFailWrite, err)
			err = svc.DeleteQuery(ctx, queryName[tt.qid])
			checkAuthErr(t, tt.shouldFailWrite, err)
			err = svc.DeleteQueryByID(ctx, tt.qid)
			checkAuthErr(t, tt.shouldFailWrite, err)
			_, err = svc.DeleteQueries(ctx, []uint{tt.qid})
			checkAuthErr(t, tt.shouldFailWrite, err)

			// Read operations.
			_, err = svc.GetQuery(ctx, tt.qid)
			checkAuthErr(t, tt.shouldFailRead, err)
			_, err = svc.ListQueries(ctx, fleet.ListOptions{})
			checkAuthErr(t, tt.shouldFailRead, err)

			// Spec operations: apply is a write, the getters are reads.
			err = svc.ApplyQuerySpecs(ctx, []*fleet.QuerySpec{{Name: queryName[tt.qid], Query: "SELECT 1"}})
			checkAuthErr(t, tt.shouldFailWrite, err)
			_, err = svc.GetQuerySpecs(ctx)
			checkAuthErr(t, tt.shouldFailRead, err)
			_, err = svc.GetQuerySpec(ctx, queryName[tt.qid])
			checkAuthErr(t, tt.shouldFailRead, err)
		})
	}
}