feat: reset yes/no count when query changes (#14776)

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
This commit is contained in:
Jahziel Villasana-Espinoza 2023-10-31 12:29:09 -04:00 committed by GitHub
parent 8162d052bf
commit b265d56d73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 15 deletions

View File

@ -108,7 +108,7 @@ func policyDB(ctx context.Context, q sqlx.QueryerContext, id uint, teamID *uint)
// SavePolicy updates some fields of the given policy on the datastore.
//
// Currently SavePolicy does not allow updating the team of an existing policy.
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error {
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
sql := `
UPDATE policies
SET name = ?, query = ?, description = ?, resolution = ?, platforms = ?, critical = ?
@ -126,6 +126,9 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy) error {
return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID))
}
if shouldRemoveAllPolicyMemberships {
return cleanupPolicyMembership(ctx, ds.writer(ctx), p.ID)
}
return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform)
}
@ -782,6 +785,25 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
}
// cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints.
// Used when we want to remove all policy membership.
func cleanupPolicyMembership(ctx context.Context, db sqlx.ExecerContext, policyID uint) error {
delStmt := `
DELETE
pm
FROM
policy_membership pm
LEFT JOIN
hosts h
ON
pm.host_id = h.id
WHERE
pm.policy_id = ?`
_, err := db.ExecContext(ctx, delStmt, policyID)
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
}
// CleanupPolicyMembership deletes the host's membership from policies that
// have been updated recently if those hosts don't meet the policy's criteria
// anymore (e.g. if the policy's platforms has been updated from "any" - the

View File

@ -1321,7 +1321,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
Name: "non-existent query",
Query: "select 1;",
},
})
}, false)
require.Error(t, err)
var nfe *notFoundError
require.True(t, errors.As(err, &nfe))
@ -1359,7 +1359,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
gp2 := *gp
gp2.Name = "global query updated"
gp2.Critical = true
err = ds.SavePolicy(ctx, &gp2)
err = ds.SavePolicy(ctx, &gp2, false)
require.NoError(t, err)
gp, err = ds.Policy(ctx, gp.ID)
require.NoError(t, err)
@ -1373,12 +1373,24 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
tp2.Description = "team1 query desc updated"
tp2.Resolution = ptr.String("team1 query resolution updated")
tp2.Critical = false
err = ds.SavePolicy(ctx, &tp2)
err = ds.SavePolicy(ctx, &tp2, true)
require.NoError(t, err)
tp1, err = ds.Policy(ctx, tp1.ID)
tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps
require.NoError(t, err)
require.Equal(t, tp1, &tp2)
loadMembershipStmt, args, err := sqlx.In(`SELECT policy_id, host_id FROM policy_membership WHERE policy_id = ?`, tp2.ID)
require.NoError(t, err)
type polHostIDs struct {
PolicyID uint `db:"policy_id"`
HostID uint `db:"host_id"`
}
var rows []polHostIDs
err = ds.writer(context.Background()).SelectContext(context.Background(), &rows, loadMembershipStmt, args...)
require.NoError(t, err)
require.Len(t, rows, 0)
}
func testPoliciesDelUser(t *testing.T, ds *Datastore) {
@ -1712,9 +1724,9 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
}
// updating without change works fine
err = ds.SavePolicy(ctx, polsByName["g1"])
err = ds.SavePolicy(ctx, polsByName["g1"], false)
require.NoError(t, err)
err = ds.SavePolicy(ctx, polsByName["t2"])
err = ds.SavePolicy(ctx, polsByName["t2"], false)
require.NoError(t, err)
// apply specs that result in an update (without change) works fine
err = ds.ApplyPolicySpecs(ctx, user.ID, []*fleet.PolicySpec{
@ -1766,7 +1778,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
g1 := polsByName["g1"]
g1.Platform = "linux"
polsByName["g1"] = g1
err = ds.SavePolicy(ctx, g1)
err = ds.SavePolicy(ctx, g1, false)
require.NoError(t, err)
wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
@ -1775,7 +1787,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
t1 := polsByName["t1"]
t1.Platform = "windows,darwin"
polsByName["t1"] = t1
err = ds.SavePolicy(ctx, t1)
err = ds.SavePolicy(ctx, t1, false)
require.NoError(t, err)
wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)

View File

@ -521,7 +521,7 @@ type Datastore interface {
// SavePolicy updates some fields of the given policy on the datastore.
//
// It is also used to update team policies.
SavePolicy(ctx context.Context, p *Policy) error
SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool) error
ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error)
PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error)

View File

@ -382,7 +382,7 @@ type PolicyFunc func(ctx context.Context, id uint) (*fleet.Policy, error)
type PolicyByNameFunc func(ctx context.Context, name string) (*fleet.Policy, error)
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy) error
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error
type ListGlobalPoliciesFunc func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error)
@ -3041,11 +3041,11 @@ func (s *DataStore) PolicyByName(ctx context.Context, name string) (*fleet.Polic
return s.PolicyByNameFunc(ctx, name)
}
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy) error {
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
s.mu.Lock()
s.SavePolicyFuncInvoked = true
s.mu.Unlock()
return s.SavePolicyFunc(ctx, p)
return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships)
}
func (s *DataStore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) {

View File

@ -68,7 +68,7 @@ func TestGlobalPoliciesAuth(t *testing.T) {
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
return nil
}
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy) error {
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
return nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {

View File

@ -1847,6 +1847,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
require.NotNil(t, mgpResp.Policy.Resolution)
assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution)
assert.Equal(t, "darwin", mgpResp.Policy.Platform)
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
ggpResp := getPolicyByIDResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), getPolicyByIDRequest{}, http.StatusOK, &ggpResp)
@ -1857,6 +1859,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
require.NotNil(t, ggpResp.Policy.Resolution)
assert.Equal(t, "some global resolution updated", *ggpResp.Policy.Resolution)
assert.Equal(t, "darwin", mgpResp.Policy.Platform)
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
policiesResponse := listGlobalPoliciesResponse{}
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
@ -1867,6 +1871,8 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
require.NotNil(t, policiesResponse.Policies[0].Resolution)
assert.Equal(t, "some global resolution updated", *policiesResponse.Policies[0].Resolution)
assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform)
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
listHostsURL := fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID)
listHostsResp := listHostsResponse{}
@ -1880,6 +1886,11 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false))
@ -1888,6 +1899,45 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 1)
mgpParams = modifyGlobalPolicyRequest{
ModifyPolicyPayload: fleet.ModifyPolicyPayload{
Query: ptr.String("select * from users;"),
},
}
mgpResp = modifyGlobalPolicyResponse{}
s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), mgpParams, http.StatusOK, &mgpResp)
require.NotNil(t, gpResp.Policy)
assert.Equal(t, "TestQuery4", mgpResp.Policy.Name)
assert.Equal(t, "select * from users;", mgpResp.Policy.Query)
assert.Equal(t, "Some description updated", mgpResp.Policy.Description)
require.NotNil(t, mgpResp.Policy.Resolution)
assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution)
assert.Equal(t, "darwin", mgpResp.Policy.Platform)
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
policiesResponse = listGlobalPoliciesResponse{}
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
require.Len(t, policiesResponse.Policies, 1)
assert.Equal(t, "TestQuery4", policiesResponse.Policies[0].Name)
assert.Equal(t, "select * from users;", policiesResponse.Policies[0].Query)
assert.Equal(t, "Some description updated", policiesResponse.Policies[0].Description)
require.NotNil(t, policiesResponse.Policies[0].Resolution)
assert.Equal(t, "some global resolution updated", *policiesResponse.Policies[0].Resolution)
assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform)
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}}
deletePolicyResp := deleteGlobalPoliciesResponse{}
s.DoJSON("POST", "/api/latest/fleet/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp)

View File

@ -366,6 +366,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
})
}
var shouldRemoveAll bool
if p.Name != nil {
policy.Name = *p.Name
}
@ -373,6 +374,11 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
policy.Description = *p.Description
}
if p.Query != nil {
if policy.Query != *p.Query {
shouldRemoveAll = true
policy.FailingHostCount = 0
policy.PassingHostCount = 0
}
policy.Query = *p.Query
}
if p.Resolution != nil {
@ -386,7 +392,7 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
}
logging.WithExtras(ctx, "name", policy.Name, "sql", policy.Query)
err = svc.ds.SavePolicy(ctx, policy)
err = svc.ds.SavePolicy(ctx, policy, shouldRemoveAll)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "saving policy")
}

View File

@ -44,7 +44,7 @@ func TestTeamPoliciesAuth(t *testing.T) {
}
return nil, nil
}
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy) error {
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
return nil
}
ds.DeleteTeamPoliciesFunc = func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {