mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
ensure previous host membership policies are deleted on enrollment (#5890)
Related to #5776, this ensures that when a host is re-enrolled on a different team we cleanup existing policy memberships in the same way we do when a host is assigned a team through the API.
This commit is contained in:
parent
9798e6abb4
commit
fc7a045209
1
changes/issue-5776-fix-orphaned-policies
Normal file
1
changes/issue-5776-fix-orphaned-policies
Normal file
@ -0,0 +1 @@
|
||||
Fixed data inconsistencies in policy counts that happened when a host was re-enrolled without a team or in a different one.
|
@ -606,6 +606,21 @@ func (ds *Datastore) GenerateHostStatusStatistics(ctx context.Context, filter fl
|
||||
return &summary, nil
|
||||
}
|
||||
|
||||
func shouldCleanTeamPolicies(currentTeamID, newTeamID *uint) bool {
|
||||
// if the host is global, then there should be nothing to clean up
|
||||
if currentTeamID == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// if the host is switching from a team to global, we should clean up
|
||||
if newTeamID == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// clean up if the host is switching to a different team
|
||||
return *currentTeamID != *newTeamID
|
||||
}
|
||||
|
||||
// EnrollHost enrolls a host
|
||||
func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey string, teamID *uint, cooldown time.Duration) (*fleet.Host, error) {
|
||||
if osqueryHostID == "" {
|
||||
@ -617,7 +632,7 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri
|
||||
zeroTime := time.Unix(0, 0).Add(24 * time.Hour)
|
||||
|
||||
var hostID int64
|
||||
err := sqlx.GetContext(ctx, tx, &host, `SELECT id, last_enrolled_at FROM hosts WHERE osquery_host_id = ?`, osqueryHostID)
|
||||
err := sqlx.GetContext(ctx, tx, &host, `SELECT id, last_enrolled_at, team_id FROM hosts WHERE osquery_host_id = ?`, osqueryHostID)
|
||||
switch {
|
||||
case err != nil && !errors.Is(err, sql.ErrNoRows):
|
||||
return ctxerr.Wrap(ctx, err, "check existing")
|
||||
@ -646,6 +661,13 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri
|
||||
return backoff.Permanent(ctxerr.Errorf(ctx, "host identified by %s enrolling too often", osqueryHostID))
|
||||
}
|
||||
hostID = int64(host.ID)
|
||||
|
||||
if shouldCleanTeamPolicies(host.TeamID, teamID) {
|
||||
if err := cleanupPolicyMembershipOnTeamChange(ctx, tx, []uint{host.ID}); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "EnrollHost delete policy membership")
|
||||
}
|
||||
}
|
||||
|
||||
// Update existing host record
|
||||
sqlUpdate := `
|
||||
UPDATE hosts
|
||||
@ -874,18 +896,11 @@ func (ds *Datastore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs [
|
||||
}
|
||||
|
||||
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
// hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts
|
||||
// it can only be from the previous team they are being transferred from
|
||||
query, args, err := sqlx.In(`DELETE FROM policy_membership
|
||||
WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "add host to team sqlx in")
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "exec AddHostsToTeam delete policy membership")
|
||||
if err := cleanupPolicyMembershipOnTeamChange(ctx, tx, hostIDs); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "AddHostsToTeam delete policy membership")
|
||||
}
|
||||
|
||||
query, args, err = sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs)
|
||||
query, args, err := sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "sqlx.In AddHostsToTeam")
|
||||
}
|
||||
|
@ -116,6 +116,7 @@ func TestHosts(t *testing.T) {
|
||||
{"SetOrUpdateDeviceAuthToken", testHostsSetOrUpdateDeviceAuthToken},
|
||||
{"OSVersions", testOSVersions},
|
||||
{"DeleteHosts", testHostsDeleteHosts},
|
||||
{"ShouldCleanTeamPolicies", testShouldCleanTeamPolicies},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
@ -4184,3 +4185,24 @@ func testHostsDeleteHosts(t *testing.T, ds *Datastore) {
|
||||
require.False(t, ok, "table: %s", hostRef)
|
||||
}
|
||||
}
|
||||
|
||||
func testShouldCleanTeamPolicies(t *testing.T, ds *Datastore) {
|
||||
var idOne uint = 1
|
||||
var idTwo uint = 2
|
||||
|
||||
cases := []struct {
|
||||
currentTeamID *uint
|
||||
newTeamID *uint
|
||||
out bool
|
||||
}{
|
||||
{nil, nil, false},
|
||||
{nil, &idOne, false},
|
||||
{&idOne, nil, true},
|
||||
{&idOne, &idOne, false},
|
||||
{&idOne, &idTwo, true},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
require.Equal(t, shouldCleanTeamPolicies(c.currentTeamID, c.newTeamID), c.out)
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,32 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
MigrationClient.AddMigration(Up_20220524102918, Down_20220524102918)
|
||||
}
|
||||
|
||||
func Up_20220524102918(tx *sql.Tx) error {
|
||||
delStmt := `
|
||||
DELETE pm
|
||||
FROM policy_membership pm
|
||||
LEFT JOIN policies p ON p.id = pm.policy_id
|
||||
LEFT JOIN hosts h ON h.id = pm.host_id
|
||||
WHERE p.team_id IS NOT NULL
|
||||
AND (p.team_id != h.team_id
|
||||
OR h.team_id IS NULL)
|
||||
`
|
||||
|
||||
if _, err := tx.Exec(delStmt); err != nil {
|
||||
return fmt.Errorf("deleting orphaned policy memberships: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Down_20220524102918(tx *sql.Tx) error {
|
||||
return nil
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
package tables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUp_20220524102918(t *testing.T) {
|
||||
db := applyUpToPrev(t)
|
||||
|
||||
res, err := db.Exec(`
|
||||
INSERT INTO teams (name)
|
||||
VALUES ('test_team')
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
teamID, err := res.LastInsertId()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO policies (name, query, description, team_id)
|
||||
VALUES ('test_policy', "", "", ?)
|
||||
`, teamID)
|
||||
require.NoError(t, err)
|
||||
policyID, err := res.LastInsertId()
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err = db.Exec(`
|
||||
INSERT INTO hosts (osquery_host_id, team_id)
|
||||
VALUES (1, ?)
|
||||
`, teamID)
|
||||
require.NoError(t, err)
|
||||
host1ID, err := res.LastInsertId()
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err = db.Exec(`
|
||||
INSERT INTO hosts (osquery_host_id, team_id)
|
||||
VALUES (2, ?)
|
||||
`, nil)
|
||||
require.NoError(t, err)
|
||||
host2ID, err := res.LastInsertId()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO policy_membership (host_id, policy_id)
|
||||
VALUES (?, ?)
|
||||
`, host1ID, policyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO policy_membership (host_id, policy_id)
|
||||
VALUES (?, ?)
|
||||
`, host2ID, policyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int
|
||||
const countStmt = `SELECT COUNT(*) FROM policy_membership`
|
||||
err = db.Get(&count, countStmt)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
|
||||
// Apply current migration.
|
||||
applyNext(t, db)
|
||||
|
||||
err = db.Get(&count, countStmt)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
var id int64
|
||||
err = db.Get(&id, `SELECT host_id FROM policy_membership`)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, host1ID)
|
||||
}
|
@ -535,6 +535,20 @@ func (ds *Datastore) AsyncBatchUpdatePolicyTimestamp(ctx context.Context, ids []
|
||||
})
|
||||
}
|
||||
|
||||
func cleanupPolicyMembershipOnTeamChange(ctx context.Context, tx sqlx.ExtContext, hostIDs []uint) error {
|
||||
// hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts
|
||||
// it can only be from the previous team they are being transferred from
|
||||
query, args, err := sqlx.In(`DELETE FROM policy_membership
|
||||
WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "clean old policy memberships sqlx in")
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "exec clean old policy memberships")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerContext, policyID uint, platforms string) error {
|
||||
if platforms == "" {
|
||||
// all platforms allowed, nothing to clean up
|
||||
|
@ -940,6 +940,8 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
|
||||
Hostname: "foo.local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
host2, err := ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
|
||||
host1, err = ds.Host(context.Background(), host1.ID)
|
||||
@ -971,6 +973,8 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
|
||||
checkPassingCount := func(expectedCount uint) {
|
||||
policies, err := ds.ListTeamPolicies(context.Background(), team1.ID)
|
||||
@ -982,17 +986,34 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
|
||||
policies, err = ds.ListGlobalPolicies(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
assert.Equal(t, uint(1), policies[0].PassingHostCount)
|
||||
assert.Equal(t, uint(2), policies[0].PassingHostCount)
|
||||
|
||||
policies, err = ds.ListTeamPolicies(context.Background(), team2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 0)
|
||||
}
|
||||
|
||||
checkPassingCount(2)
|
||||
|
||||
// team policies are removed when AddHostsToTeam is called
|
||||
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
|
||||
checkPassingCount(1)
|
||||
|
||||
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
|
||||
// team policies are not removed when a host is enrolled in the same team
|
||||
_, err = ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0)
|
||||
require.NoError(t, err)
|
||||
checkPassingCount(1)
|
||||
|
||||
// team policies are removed if the host is enrolled in a different team
|
||||
_, err = ds.EnrollHost(context.Background(), "2", "2", &team2.ID, 0)
|
||||
require.NoError(t, err)
|
||||
checkPassingCount(0)
|
||||
|
||||
// team policies are removed if the host is re-enrolled without a team
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
|
||||
checkPassingCount(1)
|
||||
_, err = ds.EnrollHost(context.Background(), "2", "2", nil, 0)
|
||||
require.NoError(t, err)
|
||||
checkPassingCount(0)
|
||||
}
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user