allow macOS hosts that turned on MDM via SSO to renew their enrollment (#14739)

for #14238
This commit is contained in:
Roberto Dip 2023-10-27 12:42:30 -03:00 committed by GitHub
parent c570ef1e21
commit 683c1dfe95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 242 additions and 37 deletions

View File

@ -0,0 +1 @@
* Allow macOS hosts that turned on MDM features via SSO to run `sudo profiles renew --type enrollment` to renew their MDM enrollment.

View File

@ -724,16 +724,26 @@ func (svc *Service) mdmSSOHandleCallbackAuth(ctx context.Context, auth fleet.Aut
svc.logger.Log("mdm-sso-callback", "IdP UserID doesn't look like an email, using raw value")
username = auth.UserID()
}
idpAcc := fleet.MDMIdPAccount{
UUID: uuid.New().String(),
err = svc.ds.InsertMDMIdPAccount(ctx, &fleet.MDMIdPAccount{
Username: username,
Fullname: auth.UserDisplayName(),
Email: auth.UserID(),
}
if err := svc.ds.InsertMDMIdPAccount(ctx, &idpAcc); err != nil {
})
if err != nil {
return "", "", "", ctxerr.Wrap(ctx, err, "saving account data from IdP")
}
idpAcc, err := svc.ds.GetMDMIdPAccountByEmail(
// use the primary db as the account might have been just
// inserted
ctxdb.RequirePrimary(ctx, true),
auth.UserID(),
)
if err != nil {
return "", "", "", ctxerr.Wrap(ctx, err, "retrieving new account data from IdP")
}
eula, err := svc.ds.MDMAppleGetEULAMetadata(ctx)
if err != nil && !fleet.IsNotFound(err) {
return "", "", "", ctxerr.Wrap(ctx, err, "getting EULA metadata")

View File

@ -2196,17 +2196,29 @@ func (ds *Datastore) InsertMDMIdPAccount(ctx context.Context, account *fleet.MDM
INSERT INTO mdm_idp_accounts
(uuid, username, fullname, email)
VALUES
(?, ?, ?, ?)
(UUID(), ?, ?, ?)
ON DUPLICATE KEY UPDATE
username = VALUES(username),
fullname = VALUES(fullname),
email = VALUES(email)`
fullname = VALUES(fullname)`
_, err := ds.writer(ctx).ExecContext(ctx, stmt, account.UUID, account.Username, account.Fullname, account.Email)
_, err := ds.writer(ctx).ExecContext(ctx, stmt, account.Username, account.Fullname, account.Email)
return ctxerr.Wrap(ctx, err, "creating new MDM IdP account")
}
func (ds *Datastore) GetMDMIdPAccount(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error) {
func (ds *Datastore) GetMDMIdPAccountByEmail(ctx context.Context, email string) (*fleet.MDMIdPAccount, error) {
stmt := `SELECT uuid, username, fullname, email FROM mdm_idp_accounts WHERE email = ?`
var acct fleet.MDMIdPAccount
err := sqlx.GetContext(ctx, ds.reader(ctx), &acct, stmt, email)
if err != nil {
if err == sql.ErrNoRows {
return nil, ctxerr.Wrap(ctx, notFound("MDMIdPAccount").WithMessage(fmt.Sprintf("with email %s", email)))
}
return nil, ctxerr.Wrap(ctx, err, "select mdm_idp_accounts by email")
}
return &acct, nil
}
func (ds *Datastore) GetMDMIdPAccountByUUID(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error) {
stmt := `SELECT uuid, username, fullname, email FROM mdm_idp_accounts WHERE uuid = ?`
var acct fleet.MDMIdPAccount
err := sqlx.GetContext(ctx, ds.reader(ctx), &acct, stmt, uuid)

View File

@ -2074,8 +2074,8 @@ func testMDMAppleHostsProfilesStatus(t *testing.T, ds *Datastore) {
func testMDMAppleIdPAccount(t *testing.T, ds *Datastore) {
ctx := context.Background()
acc := &fleet.MDMIdPAccount{
UUID: "ABC-DEF",
Username: "email@example.com",
Email: "email@example.com",
Fullname: "John Doe",
}
@ -2086,14 +2086,25 @@ func testMDMAppleIdPAccount(t *testing.T, ds *Datastore) {
err = ds.InsertMDMIdPAccount(ctx, acc)
require.NoError(t, err)
out, err := ds.GetMDMIdPAccount(ctx, acc.UUID)
out, err := ds.GetMDMIdPAccountByEmail(ctx, acc.Email)
require.NoError(t, err)
// update the acc UUID
acc.UUID = out.UUID
require.Equal(t, acc, out)
var nfe fleet.NotFoundError
out, err = ds.GetMDMIdPAccount(ctx, "BAD-TOKEN")
out, err = ds.GetMDMIdPAccountByEmail(ctx, "bad@email.com")
require.ErrorAs(t, err, &nfe)
require.Nil(t, out)
out, err = ds.GetMDMIdPAccountByUUID(ctx, acc.UUID)
require.NoError(t, err)
require.Equal(t, acc, out)
out, err = ds.GetMDMIdPAccountByUUID(ctx, "BAD-TOKEN")
require.ErrorAs(t, err, &nfe)
require.Nil(t, out)
}
func testIgnoreMDMClientError(t *testing.T, ds *Datastore) {

View File

@ -0,0 +1,55 @@
package tables
import (
"database/sql"
"fmt"
)
func init() {
MigrationClient.AddMigration(Up_20231025120016, Down_20231025120016)
}
func Up_20231025120016(tx *sql.Tx) error {
deleteDuplicatesStmt := `
DELETE a
FROM mdm_idp_accounts a
LEFT JOIN (
-- MAX(uuid) is completely arbitrary as it'll compare the UUIDs
-- lexicographically, but we don't have a better field to compare the rows.
SELECT email, MAX(uuid) as latest_uuid
FROM mdm_idp_accounts
GROUP BY email
) b ON a.email = b.email AND a.uuid = b.latest_uuid
WHERE b.latest_uuid IS NULL;
`
addIdxStmt := `
ALTER TABLE mdm_idp_accounts
ADD UNIQUE KEY unique_idp_email (email)`
addTimestampsStmt := `
ALTER TABLE mdm_idp_accounts
ADD COLUMN created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
ADD COLUMN updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP`
// delete duplicates
if _, err := tx.Exec(deleteDuplicatesStmt); err != nil {
return fmt.Errorf("failed to delete duplicated emails in mdm_idp_accounts table: %w", err)
}
// add an index to prevent further duplicates
if _, err := tx.Exec(addIdxStmt); err != nil {
return fmt.Errorf("failed to delete duplicated emails in mdm_idp_accounts table: %w", err)
}
// add missing timestamps
if _, err := tx.Exec(addTimestampsStmt); err != nil {
return fmt.Errorf("failed to delete duplicated emails in mdm_idp_accounts table: %w", err)
}
return nil
}
func Down_20231025120016(tx *sql.Tx) error {
return nil
}

View File

@ -0,0 +1,57 @@
package tables
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUp_20231025120016(t *testing.T) {
db := applyUpToPrev(t)
type idpAcc struct {
Email string `db:"email"`
UUID string `db:"uuid"`
Username string `db:"username"`
Fullname string `db:"fullname"`
}
insertStmt := `INSERT INTO mdm_idp_accounts (email, uuid, username, fullname) VALUES (?, ?, ?, ?)`
loadAccountsStmt := `
SELECT email, uuid, username, fullname
FROM mdm_idp_accounts ORDER BY uuid
`
rowsToInsert := []idpAcc{
{Email: "foo@example.com", UUID: "UUID1", Username: "foo", Fullname: "Foo"},
{Email: "foo@example.com", UUID: "UUID2", Username: "foo", Fullname: "Foo"},
{Email: "bar@example.com", UUID: "UUID3", Username: "bar", Fullname: "Bar"},
{Email: "baz@example.com", UUID: "UUID4", Username: "baz", Fullname: "Baz"},
{Email: "baz@example.com", UUID: "UUID5", Username: "baz", Fullname: "Baz"},
}
for _, r := range rowsToInsert {
_, err := db.Exec(insertStmt, r.Email, r.UUID, r.Username, r.Fullname)
require.NoError(t, err)
}
var results []idpAcc
err := db.Select(&results, loadAccountsStmt)
require.NoError(t, err)
require.Len(t, results, 5)
require.Equal(t, results, rowsToInsert)
// Apply current migration.
applyNext(t, db)
// check that duplicates are gone
results = []idpAcc{}
err = db.Select(&results, loadAccountsStmt)
require.NoError(t, err)
require.Len(t, results, 3)
require.Equal(t, results, []idpAcc{
{Email: "foo@example.com", UUID: "UUID2", Username: "foo", Fullname: "Foo"},
{Email: "bar@example.com", UUID: "UUID3", Username: "bar", Fullname: "Bar"},
{Email: "baz@example.com", UUID: "UUID5", Username: "baz", Fullname: "Baz"},
})
}

File diff suppressed because one or more lines are too long

View File

@ -952,8 +952,11 @@ type Datastore interface {
// InsertMDMIdPAccount inserts a new MDM IdP account
InsertMDMIdPAccount(ctx context.Context, account *MDMIdPAccount) error
// GetMDMIdPAccount returns MDM IdP account that matches the given token.
GetMDMIdPAccount(ctx context.Context, uuid string) (*MDMIdPAccount, error)
// GetMDMIdPAccountByUUID returns MDM IdP account that matches the given token.
GetMDMIdPAccountByUUID(ctx context.Context, uuid string) (*MDMIdPAccount, error)
// GetMDMIdPAccountByEmail returns MDM IdP account that matches the given email.
GetMDMIdPAccountByEmail(ctx context.Context, email string) (*MDMIdPAccount, error)
// GetMDMAppleFileVaultSummary summarizes the current state of Apple disk encryption profiles on
// each macOS host in the specified team (or, if no team is specified, each host that is not assigned

View File

@ -628,7 +628,9 @@ type GetMDMAppleHostsProfilesSummaryFunc func(ctx context.Context, teamID *uint)
type InsertMDMIdPAccountFunc func(ctx context.Context, account *fleet.MDMIdPAccount) error
type GetMDMIdPAccountFunc func(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error)
type GetMDMIdPAccountByUUIDFunc func(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error)
type GetMDMIdPAccountByEmailFunc func(ctx context.Context, email string) (*fleet.MDMIdPAccount, error)
type GetMDMAppleFileVaultSummaryFunc func(ctx context.Context, teamID *uint) (*fleet.MDMAppleFileVaultSummary, error)
@ -1630,8 +1632,11 @@ type DataStore struct {
InsertMDMIdPAccountFunc InsertMDMIdPAccountFunc
InsertMDMIdPAccountFuncInvoked bool
GetMDMIdPAccountFunc GetMDMIdPAccountFunc
GetMDMIdPAccountFuncInvoked bool
GetMDMIdPAccountByUUIDFunc GetMDMIdPAccountByUUIDFunc
GetMDMIdPAccountByUUIDFuncInvoked bool
GetMDMIdPAccountByEmailFunc GetMDMIdPAccountByEmailFunc
GetMDMIdPAccountByEmailFuncInvoked bool
GetMDMAppleFileVaultSummaryFunc GetMDMAppleFileVaultSummaryFunc
GetMDMAppleFileVaultSummaryFuncInvoked bool
@ -3897,11 +3902,18 @@ func (s *DataStore) InsertMDMIdPAccount(ctx context.Context, account *fleet.MDMI
return s.InsertMDMIdPAccountFunc(ctx, account)
}
func (s *DataStore) GetMDMIdPAccount(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error) {
func (s *DataStore) GetMDMIdPAccountByUUID(ctx context.Context, uuid string) (*fleet.MDMIdPAccount, error) {
s.mu.Lock()
s.GetMDMIdPAccountFuncInvoked = true
s.GetMDMIdPAccountByUUIDFuncInvoked = true
s.mu.Unlock()
return s.GetMDMIdPAccountFunc(ctx, uuid)
return s.GetMDMIdPAccountByUUIDFunc(ctx, uuid)
}
func (s *DataStore) GetMDMIdPAccountByEmail(ctx context.Context, email string) (*fleet.MDMIdPAccount, error) {
s.mu.Lock()
s.GetMDMIdPAccountByEmailFuncInvoked = true
s.mu.Unlock()
return s.GetMDMIdPAccountByEmailFunc(ctx, email)
}
func (s *DataStore) GetMDMAppleFileVaultSummary(ctx context.Context, teamID *uint) (*fleet.MDMAppleFileVaultSummary, error) {

View File

@ -6111,12 +6111,12 @@ func (s *integrationMDMTestSuite) TestSSO() {
require.Contains(t, lastSubmittedProfile.URL, acResp.ServerSettings.ServerURL+"/api/mdm/apple/enroll?token=")
require.Equal(t, acResp.ServerSettings.ServerURL+"/mdm/sso", lastSubmittedProfile.ConfigurationWebURL)
checkStoredIdPInfo := func(uuid string) {
acc, err := s.ds.GetMDMIdPAccount(context.Background(), uuid)
checkStoredIdPInfo := func(uuid, username, fullname, email string) {
acc, err := s.ds.GetMDMIdPAccountByUUID(context.Background(), uuid)
require.NoError(t, err)
require.Equal(t, "sso_user", acc.Username)
require.Equal(t, "SSO User 1", acc.Fullname)
require.Equal(t, "sso_user@example.com", acc.Email)
require.Equal(t, username, acc.Username)
require.Equal(t, fullname, acc.Fullname)
require.Equal(t, email, acc.Email)
}
res := s.LoginMDMSSOUser("sso_user", "user123#")
@ -6126,6 +6126,7 @@ func (s *integrationMDMTestSuite) TestSSO() {
u, err := url.Parse(res.Header.Get("Location"))
require.NoError(t, err)
q := u.Query()
user1EnrollRef := q.Get("enrollment_reference")
// without an EULA uploaded
require.False(t, q.Has("eula_token"))
require.True(t, q.Has("profile_token"))
@ -6136,12 +6137,12 @@ func (s *integrationMDMTestSuite) TestSSO() {
fmt.Sprintf(
"/api/mdm/apple/enroll?token=%s&enrollment_reference=%s",
q.Get("profile_token"),
q.Get("enrollment_reference"),
user1EnrollRef,
),
)
// IdP info stored is accurate for the account
checkStoredIdPInfo(q.Get("enrollment_reference"))
checkStoredIdPInfo(user1EnrollRef, "sso_user", "SSO User 1", "sso_user@example.com")
// upload an EULA
pdfBytes := []byte("%PDF-1.pdf-contents")
@ -6159,12 +6160,14 @@ func (s *integrationMDMTestSuite) TestSSO() {
require.True(t, q.Has("profile_token"))
require.True(t, q.Has("enrollment_reference"))
require.False(t, q.Has("error"))
// the enrollment reference is the same for the same user
require.Equal(t, user1EnrollRef, q.Get("enrollment_reference"))
// the url retrieves a valid profile
prof := s.downloadAndVerifyEnrollmentProfile(
fmt.Sprintf(
"/api/mdm/apple/enroll?token=%s&enrollment_reference=%s",
q.Get("profile_token"),
q.Get("enrollment_reference"),
user1EnrollRef,
),
)
// the url retrieves a valid EULA
@ -6176,7 +6179,7 @@ func (s *integrationMDMTestSuite) TestSSO() {
require.EqualValues(t, pdfBytes, respBytes)
// IdP info stored is accurate for the account
checkStoredIdPInfo(q.Get("enrollment_reference"))
checkStoredIdPInfo(user1EnrollRef, "sso_user", "SSO User 1", "sso_user@example.com")
enrollURL := ""
scepURL := ""
@ -6221,6 +6224,39 @@ func (s *integrationMDMTestSuite) TestSSO() {
require.Equal(t, "SSO User 1", accCmd.Command.AccountConfiguration.PrimaryAccountFullName)
require.Equal(t, "sso_user", accCmd.Command.AccountConfiguration.PrimaryAccountUserName)
// enrolling a different user works without problems
res = s.LoginMDMSSOUser("sso_user2", "user123#")
require.NotEmpty(t, res.Header.Get("Location"))
require.Equal(t, http.StatusTemporaryRedirect, res.StatusCode)
u, err = url.Parse(res.Header.Get("Location"))
require.NoError(t, err)
q = u.Query()
user2EnrollRef := q.Get("enrollment_reference")
require.True(t, q.Has("eula_token"))
require.True(t, q.Has("profile_token"))
require.True(t, q.Has("enrollment_reference"))
require.False(t, q.Has("error"))
// the enrollment reference is different to the one used for the previous user
require.NotEqual(t, user1EnrollRef, user2EnrollRef)
// the url retrieves a valid profile
s.downloadAndVerifyEnrollmentProfile(
fmt.Sprintf(
"/api/mdm/apple/enroll?token=%s&enrollment_reference=%s",
q.Get("profile_token"),
user2EnrollRef,
),
)
// the url retrieves a valid EULA
resp = s.DoRaw("GET", "/api/latest/fleet/mdm/apple/setup/eula/"+q.Get("eula_token"), nil, http.StatusOK)
require.EqualValues(t, len(pdfBytes), resp.ContentLength)
require.Equal(t, "application/pdf", resp.Header.Get("content-type"))
respBytes, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.EqualValues(t, pdfBytes, respBytes)
// IdP info stored is accurate for the account
checkStoredIdPInfo(user2EnrollRef, "sso_user2", "SSO User 2", "sso_user2@example.com")
// changing the server URL also updates the remote DEP profile
acResp = appConfigResponse{}
s.DoJSON("PATCH", "/api/latest/fleet/config", json.RawMessage(`{

View File

@ -79,7 +79,7 @@ func (a *AppleMDM) runPostDEPEnrollment(ctx context.Context, args appleMDMArgs)
return ctxerr.Wrap(ctx, err, "getting app config")
}
acct, err := a.Datastore.GetMDMIdPAccount(ctx, ref)
acct, err := a.Datastore.GetMDMIdPAccountByUUID(ctx, ref)
if err != nil {
return ctxerr.Wrapf(ctx, err, "getting idp account details for enroll reference %s", ref)
}

View File

@ -304,12 +304,15 @@ func TestAppleMDM(t *testing.T) {
defer mysql.TruncateTables(t, ds)
err := ds.InsertMDMIdPAccount(ctx, &fleet.MDMIdPAccount{
UUID: "abcd",
Username: "test",
Fullname: "test",
Email: "test@example.com",
})
require.NoError(t, err)
idpAcc, err := ds.GetMDMIdPAccountByEmail(ctx, "test@example.com")
require.NoError(t, err)
h := createEnrolledHost(t, 1, nil, true)
mdmWorker := &AppleMDM{
@ -320,7 +323,7 @@ func TestAppleMDM(t *testing.T) {
w := NewWorker(ds, nopLog)
w.Register(mdmWorker)
err = QueueAppleMDMJob(ctx, ds, nopLog, AppleMDMPostDEPEnrollmentTask, h.UUID, nil, "abcd")
err = QueueAppleMDMJob(ctx, ds, nopLog, AppleMDMPostDEPEnrollmentTask, h.UUID, nil, idpAcc.UUID)
require.NoError(t, err)
// run the worker, should succeed
@ -342,13 +345,15 @@ func TestAppleMDM(t *testing.T) {
defer mysql.TruncateTables(t, ds)
err := ds.InsertMDMIdPAccount(ctx, &fleet.MDMIdPAccount{
UUID: "abcd",
Username: "test",
Fullname: "test",
Email: "test@example.com",
})
require.NoError(t, err)
idpAcc, err := ds.GetMDMIdPAccountByEmail(ctx, "test@example.com")
require.NoError(t, err)
tm, err := ds.NewTeam(ctx, &fleet.Team{Name: "test"})
require.NoError(t, err)
tm, err = ds.Team(ctx, tm.ID)
@ -367,7 +372,7 @@ func TestAppleMDM(t *testing.T) {
w := NewWorker(ds, nopLog)
w.Register(mdmWorker)
err = QueueAppleMDMJob(ctx, ds, nopLog, AppleMDMPostDEPEnrollmentTask, h.UUID, &tm.ID, "abcd")
err = QueueAppleMDMJob(ctx, ds, nopLog, AppleMDMPostDEPEnrollmentTask, h.UUID, &tm.ID, idpAcc.UUID)
require.NoError(t, err)
// run the worker, should succeed