mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 00:45:19 +00:00
feat: reset yes/no count when query changes (#14776)
# Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] Added/updated tests - [x] Manual QA for all new/changed functionality
This commit is contained in:
parent
8162d052bf
commit
b265d56d73
@ -108,7 +108,7 @@ func policyDB(ctx context.Context, q sqlx.QueryerContext, id uint, teamID *uint)
|
||||
// SavePolicy updates some fields of the given policy on the datastore.
|
||||
//
|
||||
// Currently SavePolicy does not allow updating the team of an existing policy.
|
||||
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error {
|
||||
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
|
||||
sql := `
|
||||
UPDATE policies
|
||||
SET name = ?, query = ?, description = ?, resolution = ?, platforms = ?, critical = ?
|
||||
@ -126,6 +126,9 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error {
|
||||
return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID))
|
||||
}
|
||||
|
||||
if shouldRemoveAllPolicyMemberships {
|
||||
return cleanupPolicyMembership(ctx, ds.writer(ctx), p.ID)
|
||||
}
|
||||
return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform)
|
||||
}
|
||||
|
||||
@ -782,6 +785,25 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo
|
||||
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
|
||||
}
|
||||
|
||||
// cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints.
|
||||
// Used when we want to remove all policy membership.
|
||||
func cleanupPolicyMembership(ctx context.Context, db sqlx.ExecerContext, policyID uint) error {
|
||||
delStmt := `
|
||||
DELETE
|
||||
pm
|
||||
FROM
|
||||
policy_membership pm
|
||||
LEFT JOIN
|
||||
hosts h
|
||||
ON
|
||||
pm.host_id = h.id
|
||||
WHERE
|
||||
pm.policy_id = ?`
|
||||
|
||||
_, err := db.ExecContext(ctx, delStmt, policyID)
|
||||
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
|
||||
}
|
||||
|
||||
// CleanupPolicyMembership deletes the host's membership from policies that
|
||||
// have been updated recently if those hosts don't meet the policy's criteria
|
||||
// anymore (e.g. if the policy's platforms has been updated from "any" - the
|
||||
|
@ -1321,7 +1321,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
|
||||
Name: "non-existent query",
|
||||
Query: "select 1;",
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
require.Error(t, err)
|
||||
var nfe *notFoundError
|
||||
require.True(t, errors.As(err, &nfe))
|
||||
@ -1359,7 +1359,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
|
||||
gp2 := *gp
|
||||
gp2.Name = "global query updated"
|
||||
gp2.Critical = true
|
||||
err = ds.SavePolicy(ctx, &gp2)
|
||||
err = ds.SavePolicy(ctx, &gp2, false)
|
||||
require.NoError(t, err)
|
||||
gp, err = ds.Policy(ctx, gp.ID)
|
||||
require.NoError(t, err)
|
||||
@ -1373,12 +1373,24 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
|
||||
tp2.Description = "team1 query desc updated"
|
||||
tp2.Resolution = ptr.String("team1 query resolution updated")
|
||||
tp2.Critical = false
|
||||
err = ds.SavePolicy(ctx, &tp2)
|
||||
err = ds.SavePolicy(ctx, &tp2, true)
|
||||
require.NoError(t, err)
|
||||
tp1, err = ds.Policy(ctx, tp1.ID)
|
||||
tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tp1, &tp2)
|
||||
|
||||
loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id = ?`, tp2.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
type polHostIDs struct {
|
||||
PolicyID uint `db:"policy_id"`
|
||||
HostID uint `db:"host_id"`
|
||||
}
|
||||
var rows []polHostIDs
|
||||
err = ds.writer(context.Background()).SelectContext(context.Background(), &rows, loadMembershipStmt, args...)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rows, 0)
|
||||
}
|
||||
|
||||
func testPoliciesDelUser(t *testing.T, ds *Datastore) {
|
||||
@ -1712,9 +1724,9 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
||||
}
|
||||
|
||||
// updating without change works fine
|
||||
err = ds.SavePolicy(ctx, polsByName["g1"])
|
||||
err = ds.SavePolicy(ctx, polsByName["g1"], false)
|
||||
require.NoError(t, err)
|
||||
err = ds.SavePolicy(ctx, polsByName["t2"])
|
||||
err = ds.SavePolicy(ctx, polsByName["t2"], false)
|
||||
require.NoError(t, err)
|
||||
// apply specs that result in an update (without change) works fine
|
||||
err = ds.ApplyPolicySpecs(ctx, user.ID, []*fleet.PolicySpec{
|
||||
@ -1766,7 +1778,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
||||
g1 := polsByName["g1"]
|
||||
g1.Platform = "linux"
|
||||
polsByName["g1"] = g1
|
||||
err = ds.SavePolicy(ctx, g1)
|
||||
err = ds.SavePolicy(ctx, g1, false)
|
||||
require.NoError(t, err)
|
||||
wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID}
|
||||
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
|
||||
@ -1775,7 +1787,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
||||
t1 := polsByName["t1"]
|
||||
t1.Platform = "windows,darwin"
|
||||
polsByName["t1"] = t1
|
||||
err = ds.SavePolicy(ctx, t1)
|
||||
err = ds.SavePolicy(ctx, t1, false)
|
||||
require.NoError(t, err)
|
||||
wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID}
|
||||
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
|
||||
|
@ -521,7 +521,7 @@ type Datastore interface {
|
||||
// SavePolicy updates some fields of the given policy on the datastore.
|
||||
//
|
||||
// It is also used to update team policies.
|
||||
SavePolicy(ctx context.Context, p *Policy) error
|
||||
SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool) error
|
||||
|
||||
ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error)
|
||||
PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error)
|
||||
|
@ -382,7 +382,7 @@ type PolicyFunc func(ctx context.Context, id uint) (*fleet.Policy, error)
|
||||
|
||||
type PolicyByNameFunc func(ctx context.Context, name string) (*fleet.Policy, error)
|
||||
|
||||
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy) error
|
||||
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error
|
||||
|
||||
type ListGlobalPoliciesFunc func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error)
|
||||
|
||||
@ -3041,11 +3041,11 @@ func (s *DataStore) PolicyByName(ctx context.Context, name string) (*fleet.Polic
|
||||
return s.PolicyByNameFunc(ctx, name)
|
||||
}
|
||||
|
||||
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy) error {
|
||||
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
|
||||
s.mu.Lock()
|
||||
s.SavePolicyFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.SavePolicyFunc(ctx, p)
|
||||
return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships)
|
||||
}
|
||||
|
||||
func (s *DataStore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) {
|
||||
|
@ -68,7 +68,7 @@ func TestGlobalPoliciesAuth(t *testing.T) {
|
||||
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
|
||||
return nil
|
||||
}
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy) error {
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
|
||||
return nil
|
||||
}
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
|
@ -1847,6 +1847,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
||||
require.NotNil(t, mgpResp.Policy.Resolution)
|
||||
assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution)
|
||||
assert.Equal(t, "darwin", mgpResp.Policy.Platform)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
|
||||
|
||||
ggpResp := getPolicyByIDResponse{}
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), getPolicyByIDRequest{}, http.StatusOK, &ggpResp)
|
||||
@ -1857,6 +1859,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
||||
require.NotNil(t, ggpResp.Policy.Resolution)
|
||||
assert.Equal(t, "some global resolution updated", *ggpResp.Policy.Resolution)
|
||||
assert.Equal(t, "darwin", mgpResp.Policy.Platform)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
|
||||
|
||||
policiesResponse := listGlobalPoliciesResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
|
||||
@ -1867,6 +1871,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
||||
require.NotNil(t, policiesResponse.Policies[0].Resolution)
|
||||
assert.Equal(t, "some global resolution updated", *policiesResponse.Policies[0].Resolution)
|
||||
assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
|
||||
|
||||
listHostsURL := fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID)
|
||||
listHostsResp := listHostsResponse{}
|
||||
@ -1880,6 +1886,11 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
|
||||
|
||||
@ -1888,6 +1899,45 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 1)
|
||||
|
||||
mgpParams = modifyGlobalPolicyRequest{
|
||||
ModifyPolicyPayload: fleet.ModifyPolicyPayload{
|
||||
Query: ptr.String("select * from users;"),
|
||||
},
|
||||
}
|
||||
mgpResp = modifyGlobalPolicyResponse{}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), mgpParams, http.StatusOK, &mgpResp)
|
||||
require.NotNil(t, gpResp.Policy)
|
||||
assert.Equal(t, "TestQuery4", mgpResp.Policy.Name)
|
||||
assert.Equal(t, "select * from users;", mgpResp.Policy.Query)
|
||||
assert.Equal(t, "Some description updated", mgpResp.Policy.Description)
|
||||
require.NotNil(t, mgpResp.Policy.Resolution)
|
||||
assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution)
|
||||
assert.Equal(t, "darwin", mgpResp.Policy.Platform)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
policiesResponse = listGlobalPoliciesResponse{}
|
||||
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
|
||||
require.Len(t, policiesResponse.Policies, 1)
|
||||
assert.Equal(t, "TestQuery4", policiesResponse.Policies[0].Name)
|
||||
assert.Equal(t, "select * from users;", policiesResponse.Policies[0].Query)
|
||||
assert.Equal(t, "Some description updated", policiesResponse.Policies[0].Description)
|
||||
require.NotNil(t, policiesResponse.Policies[0].Resolution)
|
||||
assert.Equal(t, "some global resolution updated", *policiesResponse.Policies[0].Resolution)
|
||||
assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
|
||||
|
||||
deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}}
|
||||
deletePolicyResp := deleteGlobalPoliciesResponse{}
|
||||
s.DoJSON("POST", "/api/latest/fleet/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp)
|
||||
|
@ -366,6 +366,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
||||
})
|
||||
}
|
||||
|
||||
var shouldRemoveAll bool
|
||||
if p.Name != nil {
|
||||
policy.Name = *p.Name
|
||||
}
|
||||
@ -373,6 +374,11 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
||||
policy.Description = *p.Description
|
||||
}
|
||||
if p.Query != nil {
|
||||
if policy.Query != *p.Query {
|
||||
shouldRemoveAll = true
|
||||
policy.FailingHostCount = 0
|
||||
policy.PassingHostCount = 0
|
||||
}
|
||||
policy.Query = *p.Query
|
||||
}
|
||||
if p.Resolution != nil {
|
||||
@ -386,7 +392,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
||||
}
|
||||
logging.WithExtras(ctx, "name", policy.Name, "sql", policy.Query)
|
||||
|
||||
err = svc.ds.SavePolicy(ctx, policy)
|
||||
err = svc.ds.SavePolicy(ctx, policy, shouldRemoveAll)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "saving policy")
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ func TestTeamPoliciesAuth(t *testing.T) {
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy) error {
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
|
||||
return nil
|
||||
}
|
||||
ds.DeleteTeamPoliciesFunc = func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {
|
||||
|
Loading…
Reference in New Issue
Block a user