Migrate packs endpoints to new pattern (#3244)

This commit is contained in:
Martin Angers 2021-12-15 09:35:40 -05:00 committed by GitHub
parent af42a0850e
commit 73e1c801ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 832 additions and 1058 deletions

View File

@ -1,270 +0,0 @@
package service
import (
"context"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
)
type packResponse struct {
fleet.Pack
QueryCount uint `json:"query_count"`
// All current hosts in the pack. Hosts which are selected explicty and
// hosts which are part of a label.
TotalHostsCount uint `json:"total_hosts_count"`
// IDs of hosts which were explicitly selected.
HostIDs []uint `json:"host_ids"`
LabelIDs []uint `json:"label_ids"`
TeamIDs []uint `json:"team_ids"`
}
func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) {
opts := fleet.ListOptions{}
queries, err := svc.GetScheduledQueriesInPack(ctx, pack.ID, opts)
if err != nil {
return nil, err
}
hostMetrics, err := svc.CountHostsInTargets(
ctx,
nil,
fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs},
)
if err != nil {
return nil, err
}
return &packResponse{
Pack: pack,
QueryCount: uint(len(queries)),
TotalHostsCount: hostMetrics.TotalHosts,
HostIDs: pack.HostIDs,
LabelIDs: pack.LabelIDs,
TeamIDs: pack.TeamIDs,
}, nil
}
////////////////////////////////////////////////////////////////////////////////
// List Packs
////////////////////////////////////////////////////////////////////////////////
type listPacksRequest struct {
ListOptions fleet.ListOptions
}
type listPacksResponse struct {
Packs []packResponse `json:"packs"`
Err error `json:"error,omitempty"`
}
func (r listPacksResponse) error() error { return r.Err }
func makeListPacksEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listPacksRequest)
packs, err := svc.ListPacks(ctx, fleet.PackListOptions{ListOptions: req.ListOptions, IncludeSystemPacks: false})
if err != nil {
return getPackResponse{Err: err}, nil
}
resp := listPacksResponse{Packs: make([]packResponse, len(packs))}
for i, pack := range packs {
packResp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return getPackResponse{Err: err}, nil
}
resp.Packs[i] = *packResp
}
return resp, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Create Pack
////////////////////////////////////////////////////////////////////////////////
type createPackRequest struct {
payload fleet.PackPayload
}
type createPackResponse struct {
Pack packResponse `json:"pack,omitempty"`
Err error `json:"error,omitempty"`
}
func (r createPackResponse) error() error { return r.Err }
func makeCreatePackEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(createPackRequest)
pack, err := svc.NewPack(ctx, req.payload)
if err != nil {
return createPackResponse{Err: err}, nil
}
resp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return createPackResponse{Err: err}, nil
}
return createPackResponse{
Pack: *resp,
}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Modify Pack
////////////////////////////////////////////////////////////////////////////////
type modifyPackRequest struct {
ID uint
payload fleet.PackPayload
}
type modifyPackResponse struct {
Pack packResponse `json:"pack,omitempty"`
Err error `json:"error,omitempty"`
}
func (r modifyPackResponse) error() error { return r.Err }
func makeModifyPackEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(modifyPackRequest)
pack, err := svc.ModifyPack(ctx, req.ID, req.payload)
if err != nil {
return modifyPackResponse{Err: err}, nil
}
resp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return modifyPackResponse{Err: err}, nil
}
return modifyPackResponse{
Pack: *resp,
}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Delete Pack
////////////////////////////////////////////////////////////////////////////////
type deletePackRequest struct {
Name string
}
type deletePackResponse struct {
Err error `json:"error,omitempty"`
}
func (r deletePackResponse) error() error { return r.Err }
func makeDeletePackEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(deletePackRequest)
err := svc.DeletePack(ctx, req.Name)
if err != nil {
return deletePackResponse{Err: err}, nil
}
return deletePackResponse{}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Delete Pack By ID
////////////////////////////////////////////////////////////////////////////////
type deletePackByIDRequest struct {
ID uint
}
type deletePackByIDResponse struct {
Err error `json:"error,omitempty"`
}
func (r deletePackByIDResponse) error() error { return r.Err }
func makeDeletePackByIDEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(deletePackByIDRequest)
err := svc.DeletePackByID(ctx, req.ID)
if err != nil {
return deletePackByIDResponse{Err: err}, nil
}
return deletePackByIDResponse{}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Apply Pack Spec
////////////////////////////////////////////////////////////////////////////////
type applyPackSpecsRequest struct {
Specs []*fleet.PackSpec `json:"specs"`
}
type applyPackSpecsResponse struct {
Err error `json:"error,omitempty"`
}
func (r applyPackSpecsResponse) error() error { return r.Err }
func makeApplyPackSpecsEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(applyPackSpecsRequest)
_, err := svc.ApplyPackSpecs(ctx, req.Specs)
if err != nil {
return applyPackSpecsResponse{Err: err}, nil
}
return applyPackSpecsResponse{}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Get Pack Spec
////////////////////////////////////////////////////////////////////////////////
type getPackSpecsResponse struct {
Specs []*fleet.PackSpec `json:"specs"`
Err error `json:"error,omitempty"`
}
func (r getPackSpecsResponse) error() error { return r.Err }
func makeGetPackSpecsEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
specs, err := svc.GetPackSpecs(ctx)
if err != nil {
return getPackSpecsResponse{Err: err}, nil
}
return getPackSpecsResponse{Specs: specs}, nil
}
}
////////////////////////////////////////////////////////////////////////////////
// Get Pack Spec
////////////////////////////////////////////////////////////////////////////////
type getPackSpecResponse struct {
Spec *fleet.PackSpec `json:"specs,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getPackSpecResponse) error() error { return r.Err }
func makeGetPackSpecEndpoint(svc fleet.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(getGenericSpecRequest)
spec, err := svc.GetPackSpec(ctx, req.Name)
if err != nil {
return getPackSpecResponse{Err: err}, nil
}
return getPackSpecResponse{Spec: spec}, nil
}
}

View File

@ -60,19 +60,11 @@ type FleetEndpoints struct {
GetQuerySpec endpoint.Endpoint
CreateDistributedQueryCampaign endpoint.Endpoint
CreateDistributedQueryCampaignByNames endpoint.Endpoint
CreatePack endpoint.Endpoint
ModifyPack endpoint.Endpoint
ListPacks endpoint.Endpoint
DeletePack endpoint.Endpoint
DeletePackByID endpoint.Endpoint
GetScheduledQueriesInPack endpoint.Endpoint
ScheduleQuery endpoint.Endpoint
GetScheduledQuery endpoint.Endpoint
ModifyScheduledQuery endpoint.Endpoint
DeleteScheduledQuery endpoint.Endpoint
ApplyPackSpecs endpoint.Endpoint
GetPackSpecs endpoint.Endpoint
GetPackSpec endpoint.Endpoint
EnrollAgent endpoint.Endpoint
GetClientConfig endpoint.Endpoint
GetDistributedQueries endpoint.Endpoint
@ -167,19 +159,11 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
GetQuerySpec: authenticatedUser(svc, makeGetQuerySpecEndpoint(svc)),
CreateDistributedQueryCampaign: authenticatedUser(svc, makeCreateDistributedQueryCampaignEndpoint(svc)),
CreateDistributedQueryCampaignByNames: authenticatedUser(svc, makeCreateDistributedQueryCampaignByNamesEndpoint(svc)),
CreatePack: authenticatedUser(svc, makeCreatePackEndpoint(svc)),
ModifyPack: authenticatedUser(svc, makeModifyPackEndpoint(svc)),
ListPacks: authenticatedUser(svc, makeListPacksEndpoint(svc)),
DeletePack: authenticatedUser(svc, makeDeletePackEndpoint(svc)),
DeletePackByID: authenticatedUser(svc, makeDeletePackByIDEndpoint(svc)),
GetScheduledQueriesInPack: authenticatedUser(svc, makeGetScheduledQueriesInPackEndpoint(svc)),
ScheduleQuery: authenticatedUser(svc, makeScheduleQueryEndpoint(svc)),
GetScheduledQuery: authenticatedUser(svc, makeGetScheduledQueryEndpoint(svc)),
ModifyScheduledQuery: authenticatedUser(svc, makeModifyScheduledQueryEndpoint(svc)),
DeleteScheduledQuery: authenticatedUser(svc, makeDeleteScheduledQueryEndpoint(svc)),
ApplyPackSpecs: authenticatedUser(svc, makeApplyPackSpecsEndpoint(svc)),
GetPackSpecs: authenticatedUser(svc, makeGetPackSpecsEndpoint(svc)),
GetPackSpec: authenticatedUser(svc, makeGetPackSpecEndpoint(svc)),
CreateLabel: authenticatedUser(svc, makeCreateLabelEndpoint(svc)),
ModifyLabel: authenticatedUser(svc, makeModifyLabelEndpoint(svc)),
GetLabel: authenticatedUser(svc, makeGetLabelEndpoint(svc)),
@ -262,19 +246,11 @@ type fleetHandlers struct {
GetQuerySpec http.Handler
CreateDistributedQueryCampaign http.Handler
CreateDistributedQueryCampaignByNames http.Handler
CreatePack http.Handler
ModifyPack http.Handler
ListPacks http.Handler
DeletePack http.Handler
DeletePackByID http.Handler
GetScheduledQueriesInPack http.Handler
ScheduleQuery http.Handler
GetScheduledQuery http.Handler
ModifyScheduledQuery http.Handler
DeleteScheduledQuery http.Handler
ApplyPackSpecs http.Handler
GetPackSpecs http.Handler
GetPackSpec http.Handler
EnrollAgent http.Handler
GetClientConfig http.Handler
GetDistributedQueries http.Handler
@ -356,19 +332,11 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
GetQuerySpec: newServer(e.GetQuerySpec, decodeGetGenericSpecRequest),
CreateDistributedQueryCampaign: newServer(e.CreateDistributedQueryCampaign, decodeCreateDistributedQueryCampaignRequest),
CreateDistributedQueryCampaignByNames: newServer(e.CreateDistributedQueryCampaignByNames, decodeCreateDistributedQueryCampaignByNamesRequest),
CreatePack: newServer(e.CreatePack, decodeCreatePackRequest),
ModifyPack: newServer(e.ModifyPack, decodeModifyPackRequest),
ListPacks: newServer(e.ListPacks, decodeListPacksRequest),
DeletePack: newServer(e.DeletePack, decodeDeletePackRequest),
DeletePackByID: newServer(e.DeletePackByID, decodeDeletePackByIDRequest),
GetScheduledQueriesInPack: newServer(e.GetScheduledQueriesInPack, decodeGetScheduledQueriesInPackRequest),
ScheduleQuery: newServer(e.ScheduleQuery, decodeScheduleQueryRequest),
GetScheduledQuery: newServer(e.GetScheduledQuery, decodeGetScheduledQueryRequest),
ModifyScheduledQuery: newServer(e.ModifyScheduledQuery, decodeModifyScheduledQueryRequest),
DeleteScheduledQuery: newServer(e.DeleteScheduledQuery, decodeDeleteScheduledQueryRequest),
ApplyPackSpecs: newServer(e.ApplyPackSpecs, decodeApplyPackSpecsRequest),
GetPackSpecs: newServer(e.GetPackSpecs, decodeNoParamsRequest),
GetPackSpec: newServer(e.GetPackSpec, decodeGetGenericSpecRequest),
EnrollAgent: newServer(e.EnrollAgent, decodeEnrollAgentRequest),
GetClientConfig: newServer(e.GetClientConfig, decodeGetClientConfigRequest),
GetDistributedQueries: newServer(e.GetDistributedQueries, decodeGetDistributedQueriesRequest),
@ -551,19 +519,11 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
r.Handle("/api/v1/fleet/queries/run", h.CreateDistributedQueryCampaign).Methods("POST").Name("create_distributed_query_campaign")
r.Handle("/api/v1/fleet/queries/run_by_names", h.CreateDistributedQueryCampaignByNames).Methods("POST").Name("create_distributed_query_campaign_by_names")
r.Handle("/api/v1/fleet/packs", h.CreatePack).Methods("POST").Name("create_pack")
r.Handle("/api/v1/fleet/packs/{id:[0-9]+}", h.ModifyPack).Methods("PATCH").Name("modify_pack")
r.Handle("/api/v1/fleet/packs", h.ListPacks).Methods("GET").Name("list_packs")
r.Handle("/api/v1/fleet/packs/{name}", h.DeletePack).Methods("DELETE").Name("delete_pack")
r.Handle("/api/v1/fleet/packs/id/{id:[0-9]+}", h.DeletePackByID).Methods("DELETE").Name("delete_pack_by_id")
r.Handle("/api/v1/fleet/packs/{id:[0-9]+}/scheduled", h.GetScheduledQueriesInPack).Methods("GET").Name("get_scheduled_queries_in_pack")
r.Handle("/api/v1/fleet/schedule", h.ScheduleQuery).Methods("POST").Name("schedule_query")
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.GetScheduledQuery).Methods("GET").Name("get_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.ModifyScheduledQuery).Methods("PATCH").Name("modify_scheduled_query")
r.Handle("/api/v1/fleet/schedule/{id:[0-9]+}", h.DeleteScheduledQuery).Methods("DELETE").Name("delete_scheduled_query")
r.Handle("/api/v1/fleet/spec/packs", h.ApplyPackSpecs).Methods("POST").Name("apply_pack_specs")
r.Handle("/api/v1/fleet/spec/packs", h.GetPackSpecs).Methods("GET").Name("get_pack_specs")
r.Handle("/api/v1/fleet/spec/packs/{name}", h.GetPackSpec).Methods("GET").Name("get_pack_spec")
r.Handle("/api/v1/fleet/labels", h.CreateLabel).Methods("POST").Name("create_label")
r.Handle("/api/v1/fleet/labels/{id:[0-9]+}", h.ModifyLabel).Methods("PATCH").Name("modify_label")
@ -642,6 +602,14 @@ func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kitht
e.POST("/api/v1/fleet/spec/policies", applyPolicySpecsEndpoint, applyPolicySpecsRequest{})
e.GET("/api/v1/fleet/packs/{id:[0-9]+}", getPackEndpoint, getPackRequest{})
e.POST("/api/v1/fleet/packs", createPackEndpoint, createPackRequest{})
e.PATCH("/api/v1/fleet/packs/{id:[0-9]+}", modifyPackEndpoint, modifyPackRequest{})
e.GET("/api/v1/fleet/packs", listPacksEndpoint, listPacksRequest{})
e.DELETE("/api/v1/fleet/packs/{name}", deletePackEndpoint, deletePackRequest{})
e.DELETE("/api/v1/fleet/packs/id/{id:[0-9]+}", deletePackByIDEndpoint, deletePackByIDRequest{})
e.POST("/api/v1/fleet/spec/packs", applyPackSpecsEndpoint, applyPackSpecsRequest{})
e.GET("/api/v1/fleet/spec/packs", getPackSpecsEndpoint, nil)
e.GET("/api/v1/fleet/spec/packs/{name}", getPackSpecEndpoint, getGenericSpecRequest{})
e.GET("/api/v1/fleet/software", listSoftwareEndpoint, listSoftwareRequest{})
e.GET("/api/v1/fleet/software/count", countSoftwareEndpoint, countSoftwareRequest{})

View File

@ -113,22 +113,6 @@ func TestAPIRoutes(t *testing.T) {
verb: "POST",
uri: "/api/v1/fleet/queries/run",
},
{
verb: "GET",
uri: "/api/v1/fleet/packs",
},
{
verb: "POST",
uri: "/api/v1/fleet/packs",
},
{
verb: "PATCH",
uri: "/api/v1/fleet/packs/1",
},
{
verb: "DELETE",
uri: "/api/v1/fleet/packs/1",
},
{
verb: "GET",
uri: "/api/v1/fleet/packs/1/scheduled",

View File

@ -6,7 +6,9 @@ import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"time"
@ -627,20 +629,76 @@ func (s *integrationTestSuite) TestCountSoftware() {
assert.Equal(t, 1, resp.Count)
}
func (s *integrationTestSuite) TestGetPack() {
func (s *integrationTestSuite) TestPacks() {
t := s.T()
pack := &fleet.Pack{
Name: t.Name(),
}
pack, err := s.ds.NewPack(context.Background(), pack)
require.NoError(t, err)
var packResp getPackResponse
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID), nil, http.StatusOK, &packResp)
require.Equal(t, packResp.Pack.ID, pack.ID)
// get non-existing pack
s.Do("GET", "/api/v1/fleet/packs/999", nil, http.StatusNotFound)
s.Do("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", pack.ID+1), nil, http.StatusNotFound)
// create some packs
packs := make([]fleet.Pack, 3)
for i := range packs {
req := &createPackRequest{
PackPayload: fleet.PackPayload{
Name: ptr.String(fmt.Sprintf("%s_%d", strings.ReplaceAll(t.Name(), "/", "_"), i)),
},
}
var createResp createPackResponse
s.DoJSON("POST", "/api/v1/fleet/packs", req, http.StatusOK, &createResp)
packs[i] = createResp.Pack.Pack
}
// get existing pack
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/packs/%d", packs[0].ID), nil, http.StatusOK, &packResp)
require.Equal(t, packs[0].ID, packResp.Pack.ID)
// list packs
var listResp listPacksResponse
s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "name")
require.Len(t, listResp.Packs, 2)
assert.Equal(t, packs[0].ID, listResp.Packs[0].ID)
assert.Equal(t, packs[1].ID, listResp.Packs[1].ID)
// get page 1
s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "page", "1", "per_page", "2", "order_key", "name")
require.Len(t, listResp.Packs, 1)
assert.Equal(t, packs[2].ID, listResp.Packs[0].ID)
// get page 2, empty
s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "page", "2", "per_page", "2", "order_key", "name")
require.Len(t, listResp.Packs, 0)
var delResp deletePackResponse
// delete non-existing pack by name
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/%s", "zzz"), nil, http.StatusNotFound, &delResp)
// delete existing pack by name
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/%s", url.PathEscape(packs[0].Name)), nil, http.StatusOK, &delResp)
// delete non-existing pack by id
var delIDResp deletePackByIDResponse
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/id/%d", packs[2].ID+1), nil, http.StatusNotFound, &delIDResp)
// delete existing pack by id
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/packs/id/%d", packs[1].ID), nil, http.StatusOK, &delIDResp)
var modResp modifyPackResponse
// modify non-existing pack
req := &fleet.PackPayload{Name: ptr.String("updated_" + packs[2].Name)}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/packs/%d", packs[2].ID+1), req, http.StatusNotFound, &modResp)
// modify existing pack
req = &fleet.PackPayload{Name: ptr.String("updated_" + packs[2].Name)}
s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/packs/%d", packs[2].ID), req, http.StatusOK, &modResp)
require.Equal(t, packs[2].ID, modResp.Pack.ID)
require.Contains(t, modResp.Pack.Name, "updated_")
// list packs, only packs[2] remains
s.DoJSON("GET", "/api/v1/fleet/packs", nil, http.StatusOK, &listResp, "per_page", "2", "order_key", "name")
require.Len(t, listResp.Packs, 1)
assert.Equal(t, packs[2].ID, listResp.Packs[0].ID)
}
func (s *integrationTestSuite) TestListHosts() {

View File

@ -2,10 +2,52 @@ package service
import (
"context"
"fmt"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
)
type packResponse struct {
fleet.Pack
QueryCount uint `json:"query_count"`
// All current hosts in the pack. Hosts which are selected explicty and
// hosts which are part of a label.
TotalHostsCount uint `json:"total_hosts_count"`
// IDs of hosts which were explicitly selected.
HostIDs []uint `json:"host_ids"`
LabelIDs []uint `json:"label_ids"`
TeamIDs []uint `json:"team_ids"`
}
func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) {
opts := fleet.ListOptions{}
queries, err := svc.GetScheduledQueriesInPack(ctx, pack.ID, opts)
if err != nil {
return nil, err
}
hostMetrics, err := svc.CountHostsInTargets(
ctx,
nil,
fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs},
)
if err != nil {
return nil, err
}
return &packResponse{
Pack: pack,
QueryCount: uint(len(queries)),
TotalHostsCount: hostMetrics.TotalHosts,
HostIDs: pack.HostIDs,
LabelIDs: pack.LabelIDs,
TeamIDs: pack.TeamIDs,
}, nil
}
////////////////////////////////////////////////////////////////////////////////
// Get Pack
////////////////////////////////////////////////////////////////////////////////
@ -45,3 +87,446 @@ func (svc *Service) GetPack(ctx context.Context, id uint) (*fleet.Pack, error) {
return svc.ds.Pack(ctx, id)
}
////////////////////////////////////////////////////////////////////////////////
// Create Pack
////////////////////////////////////////////////////////////////////////////////
type createPackRequest struct {
fleet.PackPayload
}
type createPackResponse struct {
Pack packResponse `json:"pack,omitempty"`
Err error `json:"error,omitempty"`
}
func (r createPackResponse) error() error { return r.Err }
func createPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*createPackRequest)
pack, err := svc.NewPack(ctx, req.PackPayload)
if err != nil {
return createPackResponse{Err: err}, nil
}
resp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return createPackResponse{Err: err}, nil
}
return createPackResponse{
Pack: *resp,
}, nil
}
func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err
}
var pack fleet.Pack
if p.Name != nil {
pack.Name = *p.Name
}
if p.Description != nil {
pack.Description = *p.Description
}
if p.Platform != nil {
pack.Platform = *p.Platform
}
if p.Disabled != nil {
pack.Disabled = *p.Disabled
}
if p.HostIDs != nil {
pack.HostIDs = *p.HostIDs
}
if p.LabelIDs != nil {
pack.LabelIDs = *p.LabelIDs
}
if p.TeamIDs != nil {
pack.TeamIDs = *p.TeamIDs
}
_, err := svc.ds.NewPack(ctx, &pack)
if err != nil {
return nil, err
}
if err := svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeCreatedPack,
&map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name},
); err != nil {
return nil, err
}
return &pack, nil
}
////////////////////////////////////////////////////////////////////////////////
// Modify Pack
////////////////////////////////////////////////////////////////////////////////
type modifyPackRequest struct {
ID uint `json:"-" url:"id"`
fleet.PackPayload
}
type modifyPackResponse struct {
Pack packResponse `json:"pack,omitempty"`
Err error `json:"error,omitempty"`
}
func (r modifyPackResponse) error() error { return r.Err }
func modifyPackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*modifyPackRequest)
pack, err := svc.ModifyPack(ctx, req.ID, req.PackPayload)
if err != nil {
return modifyPackResponse{Err: err}, nil
}
resp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return modifyPackResponse{Err: err}, nil
}
return modifyPackResponse{
Pack: *resp,
}, nil
}
func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err
}
pack, err := svc.ds.Pack(ctx, id)
if err != nil {
return nil, err
}
if p.Name != nil && pack.EditablePackType() {
pack.Name = *p.Name
}
if p.Description != nil && pack.EditablePackType() {
pack.Description = *p.Description
}
if p.Platform != nil {
pack.Platform = *p.Platform
}
if p.Disabled != nil {
pack.Disabled = *p.Disabled
}
if p.HostIDs != nil && pack.EditablePackType() {
pack.HostIDs = *p.HostIDs
}
if p.LabelIDs != nil && pack.EditablePackType() {
pack.LabelIDs = *p.LabelIDs
}
if p.TeamIDs != nil && pack.EditablePackType() {
pack.TeamIDs = *p.TeamIDs
}
err = svc.ds.SavePack(ctx, pack)
if err != nil {
return nil, err
}
if err := svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeEditedPack,
&map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name},
); err != nil {
return nil, err
}
return pack, err
}
////////////////////////////////////////////////////////////////////////////////
// List Packs
////////////////////////////////////////////////////////////////////////////////
type listPacksRequest struct {
ListOptions fleet.ListOptions `url:"list_options"`
}
type listPacksResponse struct {
Packs []packResponse `json:"packs"`
Err error `json:"error,omitempty"`
}
func (r listPacksResponse) error() error { return r.Err }
func listPacksEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*listPacksRequest)
packs, err := svc.ListPacks(ctx, fleet.PackListOptions{ListOptions: req.ListOptions, IncludeSystemPacks: false})
if err != nil {
return getPackResponse{Err: err}, nil
}
resp := listPacksResponse{Packs: make([]packResponse, len(packs))}
for i, pack := range packs {
packResp, err := packResponseForPack(ctx, svc, *pack)
if err != nil {
return getPackResponse{Err: err}, nil
}
resp.Packs[i] = *packResp
}
return resp, nil
}
func (svc *Service) ListPacks(ctx context.Context, opt fleet.PackListOptions) ([]*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListPacks(ctx, opt)
}
////////////////////////////////////////////////////////////////////////////////
// Delete Pack
////////////////////////////////////////////////////////////////////////////////
type deletePackRequest struct {
Name string `url:"name"`
}
type deletePackResponse struct {
Err error `json:"error,omitempty"`
}
func (r deletePackResponse) error() error { return r.Err }
func deletePackEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*deletePackRequest)
err := svc.DeletePack(ctx, req.Name)
if err != nil {
return deletePackResponse{Err: err}, nil
}
return deletePackResponse{}, nil
}
func (svc *Service) DeletePack(ctx context.Context, name string) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
pack, _, err := svc.ds.PackByName(ctx, name)
if err != nil {
return err
}
// if there is a pack by this name, ensure it is not type Global or Team
if pack != nil && !pack.EditablePackType() {
return fmt.Errorf("cannot delete pack_type %s", *pack.Type)
}
if err := svc.ds.DeletePack(ctx, name); err != nil {
return err
}
return svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeDeletedPack,
&map[string]interface{}{"pack_name": name},
)
}
////////////////////////////////////////////////////////////////////////////////
// Delete Pack By ID
////////////////////////////////////////////////////////////////////////////////
type deletePackByIDRequest struct {
ID uint `url:"id"`
}
type deletePackByIDResponse struct {
Err error `json:"error,omitempty"`
}
func (r deletePackByIDResponse) error() error { return r.Err }
func deletePackByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*deletePackByIDRequest)
err := svc.DeletePackByID(ctx, req.ID)
if err != nil {
return deletePackByIDResponse{Err: err}, nil
}
return deletePackByIDResponse{}, nil
}
func (svc *Service) DeletePackByID(ctx context.Context, id uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
pack, err := svc.ds.Pack(ctx, id)
if err != nil {
return err
}
if pack != nil && !pack.EditablePackType() {
return fmt.Errorf("cannot delete pack_type %s", *pack.Type)
}
if err := svc.ds.DeletePack(ctx, pack.Name); err != nil {
return err
}
return svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeDeletedPack,
&map[string]interface{}{"pack_name": pack.Name},
)
}
////////////////////////////////////////////////////////////////////////////////
// Apply Pack Spec
////////////////////////////////////////////////////////////////////////////////
type applyPackSpecsRequest struct {
Specs []*fleet.PackSpec `json:"specs"`
}
type applyPackSpecsResponse struct {
Err error `json:"error,omitempty"`
}
func (r applyPackSpecsResponse) error() error { return r.Err }
func applyPackSpecsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*applyPackSpecsRequest)
_, err := svc.ApplyPackSpecs(ctx, req.Specs)
if err != nil {
return applyPackSpecsResponse{Err: err}, nil
}
return applyPackSpecsResponse{}, nil
}
func (svc *Service) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) ([]*fleet.PackSpec, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err
}
packs, err := svc.ds.ListPacks(ctx, fleet.PackListOptions{IncludeSystemPacks: true})
if err != nil {
return nil, err
}
namePacks := make(map[string]*fleet.Pack, len(packs))
for _, pack := range packs {
namePacks[pack.Name] = pack
}
var result []*fleet.PackSpec
// loop over incoming specs filtering out possible edits to Global or Team Packs
for _, spec := range specs {
// see for known limitations https://github.com/fleetdm/fleet/pull/1558#discussion_r684218301
// check to see if incoming spec is already in the list of packs
if p, ok := namePacks[spec.Name]; ok {
// as long as pack is editable, we'll apply it
if p.EditablePackType() {
result = append(result, spec)
}
} else {
// incoming spec is new, let's apply it
result = append(result, spec)
}
}
if err := svc.ds.ApplyPackSpecs(ctx, result); err != nil {
return nil, err
}
return result, svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeAppliedSpecPack,
&map[string]interface{}{},
)
}
////////////////////////////////////////////////////////////////////////////////
// Get Pack Specs
////////////////////////////////////////////////////////////////////////////////
type getPackSpecsResponse struct {
Specs []*fleet.PackSpec `json:"specs"`
Err error `json:"error,omitempty"`
}
func (r getPackSpecsResponse) error() error { return r.Err }
func getPackSpecsEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
specs, err := svc.GetPackSpecs(ctx)
if err != nil {
return getPackSpecsResponse{Err: err}, nil
}
return getPackSpecsResponse{Specs: specs}, nil
}
func (svc *Service) GetPackSpecs(ctx context.Context) ([]*fleet.PackSpec, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.GetPackSpecs(ctx)
}
////////////////////////////////////////////////////////////////////////////////
// Get Pack Spec
////////////////////////////////////////////////////////////////////////////////
type getPackSpecResponse struct {
Spec *fleet.PackSpec `json:"specs,omitempty"`
Err error `json:"error,omitempty"`
}
func (r getPackSpecResponse) error() error { return r.Err }
func getPackSpecEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*getGenericSpecRequest)
spec, err := svc.GetPackSpec(ctx, req.Name)
if err != nil {
return getPackSpecResponse{Err: err}, nil
}
return getPackSpecResponse{Spec: spec}, nil
}
func (svc *Service) GetPackSpec(ctx context.Context, name string) (*fleet.PackSpec, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.GetPackSpec(ctx, name)
}
////////////////////////////////////////////////////////////////////////////////
// List Packs For Host, not exposed via an endpoint
////////////////////////////////////////////////////////////////////////////////
func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListPacksForHost(ctx, hid)
}

View File

@ -5,9 +5,12 @@ import (
"testing"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -30,3 +33,270 @@ func TestGetPack(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), authz.ForbiddenErrorMessage)
}
func TestNewPackSavesTargets(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.NewPackFunc = func(ctx context.Context, pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) {
return pack, nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error {
return nil
}
packPayload := fleet.PackPayload{
Name: ptr.String("foo"),
HostIDs: &[]uint{123},
LabelIDs: &[]uint{456},
TeamIDs: &[]uint{789},
}
pack, err := svc.NewPack(test.UserContext(test.UserAdmin), packPayload)
require.NoError(t, err)
require.Len(t, pack.HostIDs, 1)
require.Len(t, pack.LabelIDs, 1)
require.Len(t, pack.TeamIDs, 1)
assert.Equal(t, uint(123), pack.HostIDs[0])
assert.Equal(t, uint(456), pack.LabelIDs[0])
assert.Equal(t, uint(789), pack.TeamIDs[0])
assert.True(t, ds.NewPackFuncInvoked)
assert.True(t, ds.NewActivityFuncInvoked)
}
func TestPacksWithDS(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
cases := []struct {
name string
fn func(t *testing.T, ds *mysql.Datastore)
}{
{"ModifyPack", testPacksModifyPack},
{"ListPacks", testPacksListPacks},
{"DeletePack", testPacksDeletePack},
{"DeletePackByID", testPacksDeletePackByID},
{"ApplyPackSpecs", testPacksApplyPackSpecs},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
defer mysql.TruncateTables(t, ds)
c.fn(t, ds)
})
}
}
func testPacksModifyPack(t *testing.T, ds *mysql.Datastore) {
svc := newTestService(ds, nil, nil)
test.AddAllHostsLabel(t, ds)
users := createTestUsers(t, ds)
globalPack, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
labelids := []uint{1, 2, 3}
hostids := []uint{4, 5, 6}
teamids := []uint{7, 8, 9}
packPayload := fleet.PackPayload{
Name: ptr.String("foo"),
Description: ptr.String("bar"),
LabelIDs: &labelids,
HostIDs: &hostids,
TeamIDs: &teamids,
}
user := users["admin1@example.com"]
pack, _ := svc.ModifyPack(test.UserContext(&user), globalPack.ID, packPayload)
require.Equal(t, "Global", pack.Name, "name for global pack should not change")
require.Equal(t, "Global pack", pack.Description, "description for global pack should not change")
require.Len(t, pack.LabelIDs, 1)
require.Len(t, pack.HostIDs, 0)
require.Len(t, pack.TeamIDs, 0)
}
func testPacksListPacks(t *testing.T, ds *mysql.Datastore) {
svc := newTestService(ds, nil, nil)
queries, err := svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false})
require.NoError(t, err)
assert.Len(t, queries, 0)
_, err = ds.NewPack(context.Background(), &fleet.Pack{
Name: "foo",
})
require.NoError(t, err)
queries, err = svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false})
require.NoError(t, err)
assert.Len(t, queries, 1)
}
func testPacksDeletePack(t *testing.T, ds *mysql.Datastore) {
test.AddAllHostsLabel(t, ds)
gp, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
users := createTestUsers(t, ds)
user := users["admin1@example.com"]
team1, err := ds.NewTeam(context.Background(), &fleet.Team{
ID: 42,
Name: "team1",
Description: "desc team1",
})
require.NoError(t, err)
tp, err := ds.EnsureTeamPack(context.Background(), team1.ID)
require.NoError(t, err)
type args struct {
ctx context.Context
name string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "cannot delete global pack",
args: args{
ctx: test.UserContext(&user),
name: gp.Name,
},
wantErr: true,
},
{
name: "cannot delete team pack",
args: args{
ctx: test.UserContext(&user),
name: tp.Name,
},
wantErr: true,
},
{
name: "delete pack that doesn't exist",
args: args{
ctx: test.UserContext(&user),
name: "foo",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newTestService(ds, nil, nil)
if err := svc.DeletePack(tt.args.ctx, tt.args.name); (err != nil) != tt.wantErr {
t.Errorf("DeletePack() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func testPacksDeletePackByID(t *testing.T, ds *mysql.Datastore) {
test.AddAllHostsLabel(t, ds)
globalPack, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
type args struct {
ctx context.Context
id uint
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "cannot delete global pack",
args: args{
ctx: test.UserContext(test.UserAdmin),
id: globalPack.ID,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newTestService(ds, nil, nil)
if err := svc.DeletePackByID(tt.args.ctx, tt.args.id); (err != nil) != tt.wantErr {
t.Errorf("DeletePackByID() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func testPacksApplyPackSpecs(t *testing.T, ds *mysql.Datastore) {
test.AddAllHostsLabel(t, ds)
global, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
users := createTestUsers(t, ds)
user := users["admin1@example.com"]
team1, err := ds.NewTeam(context.Background(), &fleet.Team{
ID: 42,
Name: "team1",
Description: "desc team1",
})
require.NoError(t, err)
teamPack, err := ds.EnsureTeamPack(context.Background(), team1.ID)
require.NoError(t, err)
type args struct {
ctx context.Context
specs []*fleet.PackSpec
}
tests := []struct {
name string
args args
want []*fleet.PackSpec
wantErr bool
}{
{
name: "cannot modify global pack",
args: args{
ctx: test.UserContext(&user),
specs: []*fleet.PackSpec{
{Name: global.Name, Description: "bar", Platform: "baz"},
{Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"},
{Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"},
},
},
want: []*fleet.PackSpec{
{Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"},
{Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"},
},
wantErr: false,
},
{
name: "cannot modify team pack",
args: args{
ctx: test.UserContext(&user),
specs: []*fleet.PackSpec{
{Name: teamPack.Name, Description: "Desc", Platform: "windows"},
{Name: "Test", Description: "Test Desc", Platform: "linux"},
},
},
want: []*fleet.PackSpec{
{Name: "Test", Description: "Test Desc", Platform: "linux"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newTestService(ds, nil, nil)
got, err := svc.ApplyPackSpecs(tt.args.ctx, tt.args.specs)
if (err != nil) != tt.wantErr {
t.Errorf("ApplyPackSpecs() error = %v, wantErr %v", err, tt.wantErr)
return
}
require.Equal(t, tt.want, got)
})
}
}

View File

@ -1,242 +0,0 @@
package service
import (
"context"
"fmt"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/fleet"
)
func (svc *Service) ApplyPackSpecs(ctx context.Context, specs []*fleet.PackSpec) ([]*fleet.PackSpec, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err
}
packs, err := svc.ds.ListPacks(ctx, fleet.PackListOptions{IncludeSystemPacks: true})
if err != nil {
return nil, err
}
namePacks := make(map[string]*fleet.Pack, len(packs))
for _, pack := range packs {
namePacks[pack.Name] = pack
}
var result []*fleet.PackSpec
// loop over incoming specs filtering out possible edits to Global or Team Packs
for _, spec := range specs {
// see for known limitations https://github.com/fleetdm/fleet/pull/1558#discussion_r684218301
// check to see if incoming spec is already in the list of packs
if p, ok := namePacks[spec.Name]; ok {
// as long as pack is editable, we'll apply it
if p.EditablePackType() {
result = append(result, spec)
}
} else {
// incoming spec is new, let's apply it
result = append(result, spec)
}
}
if err := svc.ds.ApplyPackSpecs(ctx, result); err != nil {
return nil, err
}
return result, svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeAppliedSpecPack,
&map[string]interface{}{},
)
}
func (svc *Service) GetPackSpecs(ctx context.Context) ([]*fleet.PackSpec, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.GetPackSpecs(ctx)
}
func (svc *Service) GetPackSpec(ctx context.Context, name string) (*fleet.PackSpec, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.GetPackSpec(ctx, name)
}
func (svc *Service) ListPacks(ctx context.Context, opt fleet.PackListOptions) ([]*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListPacks(ctx, opt)
}
func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err
}
var pack fleet.Pack
if p.Name != nil {
pack.Name = *p.Name
}
if p.Description != nil {
pack.Description = *p.Description
}
if p.Platform != nil {
pack.Platform = *p.Platform
}
if p.Disabled != nil {
pack.Disabled = *p.Disabled
}
if p.HostIDs != nil {
pack.HostIDs = *p.HostIDs
}
if p.LabelIDs != nil {
pack.LabelIDs = *p.LabelIDs
}
if p.TeamIDs != nil {
pack.TeamIDs = *p.TeamIDs
}
_, err := svc.ds.NewPack(ctx, &pack)
if err != nil {
return nil, err
}
if err := svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeCreatedPack,
&map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name},
); err != nil {
return nil, err
}
return &pack, nil
}
func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload) (*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return nil, err
}
pack, err := svc.ds.Pack(ctx, id)
if err != nil {
return nil, err
}
if p.Name != nil && pack.EditablePackType() {
pack.Name = *p.Name
}
if p.Description != nil && pack.EditablePackType() {
pack.Description = *p.Description
}
if p.Platform != nil {
pack.Platform = *p.Platform
}
if p.Disabled != nil {
pack.Disabled = *p.Disabled
}
if p.HostIDs != nil && pack.EditablePackType() {
pack.HostIDs = *p.HostIDs
}
if p.LabelIDs != nil && pack.EditablePackType() {
pack.LabelIDs = *p.LabelIDs
}
if p.TeamIDs != nil && pack.EditablePackType() {
pack.TeamIDs = *p.TeamIDs
}
err = svc.ds.SavePack(ctx, pack)
if err != nil {
return nil, err
}
if err := svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeEditedPack,
&map[string]interface{}{"pack_id": pack.ID, "pack_name": pack.Name},
); err != nil {
return nil, err
}
return pack, err
}
func (svc *Service) DeletePack(ctx context.Context, name string) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
pack, _, err := svc.ds.PackByName(ctx, name)
if err != nil {
return err
}
// if there is a pack by this name, ensure it is not type Global or Team
if pack != nil && !pack.EditablePackType() {
return fmt.Errorf("cannot delete pack_type %s", *pack.Type)
}
if err := svc.ds.DeletePack(ctx, name); err != nil {
return err
}
return svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeDeletedPack,
&map[string]interface{}{"pack_name": name},
)
}
func (svc *Service) DeletePackByID(ctx context.Context, id uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
pack, err := svc.ds.Pack(ctx, id)
if err != nil {
return err
}
if pack != nil && !pack.EditablePackType() {
return fmt.Errorf("cannot delete pack_type %s", *pack.Type)
}
if err := svc.ds.DeletePack(ctx, pack.Name); err != nil {
return err
}
return svc.ds.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeDeletedPack,
&map[string]interface{}{"pack_name": pack.Name},
)
}
func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListPacksForHost(ctx, hid)
}

View File

@ -1,299 +0,0 @@
package service
import (
"context"
"testing"
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceListPacks(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
svc := newTestService(ds, nil, nil)
queries, err := svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false})
assert.Nil(t, err)
assert.Len(t, queries, 0)
_, err = ds.NewPack(context.Background(), &fleet.Pack{
Name: "foo",
})
assert.Nil(t, err)
queries, err = svc.ListPacks(test.UserContext(test.UserAdmin), fleet.PackListOptions{IncludeSystemPacks: false})
assert.Nil(t, err)
assert.Len(t, queries, 1)
}
func TestNewSavesTargets(t *testing.T) {
ds := new(mock.Store)
svc := newTestService(ds, nil, nil)
ds.NewPackFunc = func(ctx context.Context, pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) {
return pack, nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activityType string, details *map[string]interface{}) error {
return nil
}
packPayload := fleet.PackPayload{
Name: ptr.String("foo"),
HostIDs: &[]uint{123},
LabelIDs: &[]uint{456},
TeamIDs: &[]uint{789},
}
pack, _ := svc.NewPack(test.UserContext(test.UserAdmin), packPayload)
require.Len(t, pack.HostIDs, 1)
require.Len(t, pack.LabelIDs, 1)
require.Len(t, pack.TeamIDs, 1)
assert.Equal(t, uint(123), pack.HostIDs[0])
assert.Equal(t, uint(456), pack.LabelIDs[0])
assert.Equal(t, uint(789), pack.TeamIDs[0])
}
func TestService_ModifyPack_GlobalPack(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
svc := newTestService(ds, nil, nil)
test.AddAllHostsLabel(t, ds)
users := createTestUsers(t, ds)
globalPack, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
labelids := []uint{1, 2, 3}
hostids := []uint{4, 5, 6}
teamids := []uint{7, 8, 9}
packPayload := fleet.PackPayload{
Name: ptr.String("foo"),
Description: ptr.String("bar"),
LabelIDs: &labelids,
HostIDs: &hostids,
TeamIDs: &teamids,
}
user := users["admin1@example.com"]
pack, _ := svc.ModifyPack(test.UserContext(&user), globalPack.ID, packPayload)
require.Equal(t, "Global", pack.Name, "name for global pack should not change")
require.Equal(t, "Global pack", pack.Description, "description for global pack should not change")
require.Len(t, pack.LabelIDs, 1)
require.Len(t, pack.HostIDs, 0)
require.Len(t, pack.TeamIDs, 0)
}
func TestService_DeletePackByID_GlobalPack(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
test.AddAllHostsLabel(t, ds)
globalPack, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
type fields struct {
ds fleet.Datastore
}
type args struct {
ctx context.Context
id uint
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
name: "cannot delete global pack",
fields: fields{
ds,
},
args: args{
ctx: test.UserContext(test.UserAdmin),
id: globalPack.ID,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newTestService(tt.fields.ds, nil, nil)
if err := svc.DeletePackByID(tt.args.ctx, tt.args.id); (err != nil) != tt.wantErr {
t.Errorf("DeletePackByID() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestService_ApplyPackSpecs(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
test.AddAllHostsLabel(t, ds)
global, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
users := createTestUsers(t, ds)
user := users["admin1@example.com"]
team1, err := ds.NewTeam(context.Background(), &fleet.Team{
ID: 42,
Name: "team1",
Description: "desc team1",
})
require.NoError(t, err)
teamPack, err := ds.EnsureTeamPack(context.Background(), team1.ID)
require.NoError(t, err)
type fields struct {
ds fleet.Datastore
}
type args struct {
ctx context.Context
specs []*fleet.PackSpec
}
tests := []struct {
name string
fields fields
args args
want []*fleet.PackSpec
wantErr bool
}{
{
name: "cannot modify global pack",
fields: fields{
ds,
},
args: args{
ctx: test.UserContext(&user),
specs: []*fleet.PackSpec{
{Name: global.Name, Description: "bar", Platform: "baz"},
{Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"},
{Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"},
},
},
want: []*fleet.PackSpec{
{Name: "Foo Pack", Description: "Foo Desc", Platform: "MacOS"},
{Name: "Bar Pack", Description: "Bar Desc", Platform: "MacOS"},
},
wantErr: false,
},
{
name: "cannot modify team pack",
fields: fields{
ds,
},
args: args{
ctx: test.UserContext(&user),
specs: []*fleet.PackSpec{
{Name: teamPack.Name, Description: "Desc", Platform: "windows"},
{Name: "Test", Description: "Test Desc", Platform: "linux"},
},
},
want: []*fleet.PackSpec{
{Name: "Test", Description: "Test Desc", Platform: "linux"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newTestService(tt.fields.ds, nil, nil)
got, err := svc.ApplyPackSpecs(tt.args.ctx, tt.args.specs)
if (err != nil) != tt.wantErr {
t.Errorf("ApplyPackSpecs() error = %v, wantErr %v", err, tt.wantErr)
return
}
require.Equal(t, tt.want, got)
})
}
}
func TestService_DeletePack(t *testing.T) {
ds := mysql.CreateMySQLDS(t)
defer ds.Close()
test.AddAllHostsLabel(t, ds)
gp, err := ds.EnsureGlobalPack(context.Background())
require.NoError(t, err)
users := createTestUsers(t, ds)
user := users["admin1@example.com"]
team1, err := ds.NewTeam(context.Background(), &fleet.Team{
ID: 42,
Name: "team1",
Description: "desc team1",
})
require.NoError(t, err)
tp, err := ds.EnsureTeamPack(context.Background(), team1.ID)
require.NoError(t, err)
type fields struct {
ds fleet.Datastore
}
type args struct {
ctx context.Context
name string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{
name: "cannot delete global pack",
fields: fields{
ds: ds,
},
args: args{
ctx: test.UserContext(&user),
name: gp.Name,
},
wantErr: true,
},
{
name: "cannot delete team pack",
fields: fields{
ds: ds,
},
args: args{
ctx: test.UserContext(&user),
name: tp.Name,
},
wantErr: true,
},
{
name: "delete pack that doesn't exist",
fields: fields{
ds: ds,
},
args: args{
ctx: test.UserContext(&user),
name: "foo",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := newTestService(tt.fields.ds, nil, nil)
if err := svc.DeletePack(tt.args.ctx, tt.args.name); (err != nil) != tt.wantErr {
t.Errorf("DeletePack() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@ -299,7 +299,7 @@ func decodeNoParamsRequest(ctx context.Context, r *http.Request) (interface{}, e
}
type getGenericSpecRequest struct {
Name string
Name string `url:"name"`
}
func decodeGetGenericSpecRequest(ctx context.Context, r *http.Request) (interface{}, error) {

View File

@ -1,76 +0,0 @@
package service
import (
"context"
"encoding/json"
"net/http"
)
func decodeCreatePackRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req createPackRequest
if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil {
return nil, err
}
return req, nil
}
func decodeModifyPackRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
var req modifyPackRequest
if err := json.NewDecoder(r.Body).Decode(&req.payload); err != nil {
return nil, err
}
req.ID = uint(id)
return req, nil
}
func decodeDeletePackRequest(ctx context.Context, r *http.Request) (interface{}, error) {
name, err := stringFromRequest(r, "name")
if err != nil {
return nil, err
}
var req deletePackRequest
req.Name = name
return req, nil
}
func decodeDeletePackByIDRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
var req deletePackByIDRequest
req.ID = uint(id)
return req, nil
}
func decodeGetPackRequest(ctx context.Context, r *http.Request) (interface{}, error) {
id, err := uintFromRequest(r, "id")
if err != nil {
return nil, err
}
var req getPackRequest
req.ID = uint(id)
return req, nil
}
func decodeListPacksRequest(ctx context.Context, r *http.Request) (interface{}, error) {
opt, err := listOptionsFromRequest(r)
if err != nil {
return nil, err
}
return listPacksRequest{ListOptions: opt}, nil
}
func decodeApplyPackSpecsRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var req applyPackSpecsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}
return req, nil
}

View File

@ -1,104 +0,0 @@
package service
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDecodeCreatePackRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/packs", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeCreatePackRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(createPackRequest)
assert.Equal(t, "foo", *params.payload.Name)
assert.Equal(t, "bar", *params.payload.Description)
require.NotNil(t, params.payload.HostIDs)
assert.Len(t, *params.payload.HostIDs, 3)
require.NotNil(t, params.payload.LabelIDs)
assert.Len(t, *params.payload.LabelIDs, 2)
}).Methods("POST")
var body bytes.Buffer
body.Write([]byte(`{
"name": "foo",
"description": "bar",
"host_ids": [1, 2, 3],
"label_ids": [1, 5]
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("POST", "/api/v1/fleet/packs", &body),
)
}
func TestDecodeModifyPackRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/packs/{id}", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeModifyPackRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(modifyPackRequest)
assert.Equal(t, uint(1), params.ID)
assert.Equal(t, "foo", *params.payload.Name)
assert.Equal(t, "bar", *params.payload.Description)
require.NotNil(t, params.payload.HostIDs)
assert.Len(t, *params.payload.HostIDs, 3)
require.NotNil(t, params.payload.LabelIDs)
assert.Len(t, *params.payload.LabelIDs, 2)
}).Methods("PATCH")
var body bytes.Buffer
body.Write([]byte(`{
"name": "foo",
"description": "bar",
"host_ids": [1, 2, 3],
"label_ids": [1, 5]
}`))
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("PATCH", "/api/v1/fleet/packs/1", &body),
)
}
func TestDecodeDeletePackRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/packs/{name}", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeDeletePackRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(deletePackRequest)
assert.Equal(t, "packaday", params.Name)
}).Methods("DELETE")
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("DELETE", "/api/v1/fleet/packs/packaday", nil),
)
}
func TestDecodeGetPackRequest(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/v1/fleet/packs/{id}", func(writer http.ResponseWriter, request *http.Request) {
r, err := decodeGetPackRequest(context.Background(), request)
assert.Nil(t, err)
params := r.(getPackRequest)
assert.Equal(t, uint(1), params.ID)
}).Methods("GET")
router.ServeHTTP(
httptest.NewRecorder(),
httptest.NewRequest("GET", "/api/v1/fleet/packs/1", nil),
)
}