ensure previous host membership policies are deleted on enrollment (#5890)

Related to #5776, this ensures that when a host is re-enrolled on a different team we cleanup existing policy memberships in the same way we do when a host is assigned a team through the API.
This commit is contained in:
Roberto Dip 2022-05-30 10:30:15 -03:00 committed by GitHub
parent 9798e6abb4
commit fc7a045209
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 193 additions and 15 deletions

View File

@ -0,0 +1 @@
Fixed data inconsistencies in policy counts that happened when a host was re-enrolled without a team or in a different one.

View File

@ -606,6 +606,21 @@ func (ds *Datastore) GenerateHostStatusStatistics(ctx context.Context, filter fl
return &summary, nil
}
func shouldCleanTeamPolicies(currentTeamID, newTeamID *uint) bool {
// if the host is global, then there should be nothing to clean up
if currentTeamID == nil {
return false
}
// if the host is switching from a team to global, we should clean up
if newTeamID == nil {
return true
}
// clean up if the host is switching to a different team
return *currentTeamID != *newTeamID
}
// EnrollHost enrolls a host
func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey string, teamID *uint, cooldown time.Duration) (*fleet.Host, error) {
if osqueryHostID == "" {
@ -617,7 +632,7 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri
zeroTime := time.Unix(0, 0).Add(24 * time.Hour)
var hostID int64
err := sqlx.GetContext(ctx, tx, &host, `SELECT id, last_enrolled_at FROM hosts WHERE osquery_host_id = ?`, osqueryHostID)
err := sqlx.GetContext(ctx, tx, &host, `SELECT id, last_enrolled_at, team_id FROM hosts WHERE osquery_host_id = ?`, osqueryHostID)
switch {
case err != nil && !errors.Is(err, sql.ErrNoRows):
return ctxerr.Wrap(ctx, err, "check existing")
@ -646,6 +661,13 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri
return backoff.Permanent(ctxerr.Errorf(ctx, "host identified by %s enrolling too often", osqueryHostID))
}
hostID = int64(host.ID)
if shouldCleanTeamPolicies(host.TeamID, teamID) {
if err := cleanupPolicyMembershipOnTeamChange(ctx, tx, []uint{host.ID}); err != nil {
return ctxerr.Wrap(ctx, err, "EnrollHost delete policy membership")
}
}
// Update existing host record
sqlUpdate := `
UPDATE hosts
@ -874,18 +896,11 @@ func (ds *Datastore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs [
}
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts
// it can only be from the previous team they are being transferred from
query, args, err := sqlx.In(`DELETE FROM policy_membership
WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "add host to team sqlx in")
}
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return ctxerr.Wrap(ctx, err, "exec AddHostsToTeam delete policy membership")
if err := cleanupPolicyMembershipOnTeamChange(ctx, tx, hostIDs); err != nil {
return ctxerr.Wrap(ctx, err, "AddHostsToTeam delete policy membership")
}
query, args, err = sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs)
query, args, err := sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "sqlx.In AddHostsToTeam")
}

View File

@ -116,6 +116,7 @@ func TestHosts(t *testing.T) {
{"SetOrUpdateDeviceAuthToken", testHostsSetOrUpdateDeviceAuthToken},
{"OSVersions", testOSVersions},
{"DeleteHosts", testHostsDeleteHosts},
{"ShouldCleanTeamPolicies", testShouldCleanTeamPolicies},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -4184,3 +4185,24 @@ func testHostsDeleteHosts(t *testing.T, ds *Datastore) {
require.False(t, ok, "table: %s", hostRef)
}
}
func testShouldCleanTeamPolicies(t *testing.T, ds *Datastore) {
var idOne uint = 1
var idTwo uint = 2
cases := []struct {
currentTeamID *uint
newTeamID *uint
out bool
}{
{nil, nil, false},
{nil, &idOne, false},
{&idOne, nil, true},
{&idOne, &idOne, false},
{&idOne, &idTwo, true},
}
for _, c := range cases {
require.Equal(t, shouldCleanTeamPolicies(c.currentTeamID, c.newTeamID), c.out)
}
}

View File

@ -0,0 +1,32 @@
package tables
import (
"database/sql"
"fmt"
)
func init() {
MigrationClient.AddMigration(Up_20220524102918, Down_20220524102918)
}
func Up_20220524102918(tx *sql.Tx) error {
delStmt := `
DELETE pm
FROM policy_membership pm
LEFT JOIN policies p ON p.id = pm.policy_id
LEFT JOIN hosts h ON h.id = pm.host_id
WHERE p.team_id IS NOT NULL
AND (p.team_id != h.team_id
OR h.team_id IS NULL)
`
if _, err := tx.Exec(delStmt); err != nil {
return fmt.Errorf("deleting orphaned policy memberships: %w", err)
}
return nil
}
func Down_20220524102918(tx *sql.Tx) error {
return nil
}

View File

@ -0,0 +1,73 @@
package tables
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUp_20220524102918(t *testing.T) {
db := applyUpToPrev(t)
res, err := db.Exec(`
INSERT INTO teams (name)
VALUES ('test_team')
`)
require.NoError(t, err)
teamID, err := res.LastInsertId()
require.NoError(t, err)
_, err = db.Exec(`
INSERT INTO policies (name, query, description, team_id)
VALUES ('test_policy', "", "", ?)
`, teamID)
require.NoError(t, err)
policyID, err := res.LastInsertId()
require.NoError(t, err)
res, err = db.Exec(`
INSERT INTO hosts (osquery_host_id, team_id)
VALUES (1, ?)
`, teamID)
require.NoError(t, err)
host1ID, err := res.LastInsertId()
require.NoError(t, err)
res, err = db.Exec(`
INSERT INTO hosts (osquery_host_id, team_id)
VALUES (2, ?)
`, nil)
require.NoError(t, err)
host2ID, err := res.LastInsertId()
require.NoError(t, err)
_, err = db.Exec(`
INSERT INTO policy_membership (host_id, policy_id)
VALUES (?, ?)
`, host1ID, policyID)
require.NoError(t, err)
_, err = db.Exec(`
INSERT INTO policy_membership (host_id, policy_id)
VALUES (?, ?)
`, host2ID, policyID)
require.NoError(t, err)
var count int
const countStmt = `SELECT COUNT(*) FROM policy_membership`
err = db.Get(&count, countStmt)
require.NoError(t, err)
require.Equal(t, 2, count)
// Apply current migration.
applyNext(t, db)
err = db.Get(&count, countStmt)
require.NoError(t, err)
require.Equal(t, 1, count)
var id int64
err = db.Get(&id, `SELECT host_id FROM policy_membership`)
require.NoError(t, err)
require.Equal(t, id, host1ID)
}

View File

@ -535,6 +535,20 @@ func (ds *Datastore) AsyncBatchUpdatePolicyTimestamp(ctx context.Context, ids []
})
}
func cleanupPolicyMembershipOnTeamChange(ctx context.Context, tx sqlx.ExtContext, hostIDs []uint) error {
// hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts
// it can only be from the previous team they are being transferred from
query, args, err := sqlx.In(`DELETE FROM policy_membership
WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "clean old policy memberships sqlx in")
}
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return ctxerr.Wrap(ctx, err, "exec clean old policy memberships")
}
return nil
}
func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerContext, policyID uint, platforms string) error {
if platforms == "" {
// all platforms allowed, nothing to clean up

View File

@ -940,6 +940,8 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
Hostname: "foo.local",
})
require.NoError(t, err)
host2, err := ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0)
require.NoError(t, err)
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
host1, err = ds.Host(context.Background(), host1.ID)
@ -971,6 +973,8 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
checkPassingCount := func(expectedCount uint) {
policies, err := ds.ListTeamPolicies(context.Background(), team1.ID)
@ -982,17 +986,34 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
policies, err = ds.ListGlobalPolicies(context.Background())
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, uint(1), policies[0].PassingHostCount)
assert.Equal(t, uint(2), policies[0].PassingHostCount)
policies, err = ds.ListTeamPolicies(context.Background(), team2.ID)
require.NoError(t, err)
require.Len(t, policies, 0)
}
checkPassingCount(2)
// team policies are removed when AddHostsToTeam is called
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
checkPassingCount(1)
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
// team policies are not removed when a host is enrolled in the same team
_, err = ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0)
require.NoError(t, err)
checkPassingCount(1)
// team policies are removed if the host is enrolled in a different team
_, err = ds.EnrollHost(context.Background(), "2", "2", &team2.ID, 0)
require.NoError(t, err)
checkPassingCount(0)
// team policies are removed if the host is re-enrolled without a team
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
checkPassingCount(1)
_, err = ds.EnrollHost(context.Background(), "2", "2", nil, 0)
require.NoError(t, err)
checkPassingCount(0)
}

File diff suppressed because one or more lines are too long