fix issues with migration flow (#13297)

For #13094
This commit is contained in:
Roberto Dip 2023-08-14 09:56:59 -03:00 committed by GitHub
parent 700fdd3999
commit 902e064d04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 273 additions and 56 deletions

View File

@ -0,0 +1 @@
* Fixed issues with end user migration flow.

View File

@ -10,10 +10,12 @@ import (
"time"
"github.com/fleetdm/fleet/v4/orbit/pkg/constant"
"github.com/fleetdm/fleet/v4/orbit/pkg/profiles"
"github.com/fleetdm/fleet/v4/orbit/pkg/token"
"github.com/fleetdm/fleet/v4/orbit/pkg/update"
"github.com/fleetdm/fleet/v4/orbit/pkg/useraction"
"github.com/fleetdm/fleet/v4/pkg/certificate"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/pkg/open"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/service"
@ -237,7 +239,6 @@ func main() {
)
mdmMigrator = useraction.NewMDMMigrator(
swiftDialogPath,
fleetURL,
15*time.Minute,
&mdmMigrationHandler{
client: client,
@ -297,28 +298,51 @@ func main() {
}
myDeviceItem.Enable()
if runtime.GOOS == "darwin" &&
(sum.Notifications.NeedsMDMMigration || sum.Notifications.RenewEnrollmentProfile) &&
mdmMigrator.CanRun() {
shouldRunMigrator := sum.Notifications.NeedsMDMMigration || sum.Notifications.RenewEnrollmentProfile
// update org info in case it changed
mdmMigrator.SetProps(useraction.MDMMigratorProps{
OrgInfo: sum.Config.OrgInfo,
IsUnmanaged: sum.Notifications.RenewEnrollmentProfile,
})
if runtime.GOOS == "darwin" && shouldRunMigrator && mdmMigrator.CanRun() {
enrolled, enrollURL, err := profiles.IsEnrolledInMDM()
if err != nil {
log.Error().Err(err).Msg("fetching enrollment status to show mdm migrator")
continue
}
// enable tray items
migrateMDMItem.Enable()
migrateMDMItem.Show()
// we perform this check locally on the client too to avoid showing the
// dialog if the client has already migrated but the Fleet server
// doesn't know about this state yet.
enrolledIntoFleet, err := fleethttp.HostnamesMatch(enrollURL, fleetURL)
if err != nil {
log.Error().Err(err).Msg("comparing MDM server URLs")
continue
}
if !enrolledIntoFleet {
// isUnmanaged captures two important bits of information:
//
// - The notification coming from the server, which is based on information that's
// not available in the client (eg: is MDM configured? are migrations enabled?
// is this device elegible for migration?)
// - The current enrollment status of the device.
isUnmanaged := sum.Notifications.RenewEnrollmentProfile && !enrolled
forceModeEnabled := sum.Notifications.NeedsMDMMigration &&
sum.Config.MDM.MacOSMigration.Mode == fleet.MacOSMigrationModeForced
// if the device is unmanaged or we're
// in force mode and the device needs
// migration, enable aggressive mode.
if sum.Notifications.RenewEnrollmentProfile ||
(sum.Notifications.NeedsMDMMigration && sum.Config.MDM.MacOSMigration.Mode == fleet.MacOSMigrationModeForced) {
log.Info().Msg("MDM device is unmanaged or force mode enabled, automatically showing dialog")
if err := mdmMigrator.ShowInterval(); err != nil {
log.Error().Err(err).Msg("showing MDM migration dialog at interval")
// update org info in case it changed
mdmMigrator.SetProps(useraction.MDMMigratorProps{
OrgInfo: sum.Config.OrgInfo,
IsUnmanaged: isUnmanaged,
})
// enable tray items
migrateMDMItem.Enable()
migrateMDMItem.Show()
// if the device is unmanaged or we're in force mode and the device needs
// migration, enable aggressive mode.
if isUnmanaged || forceModeEnabled {
log.Info().Msg("MDM device is unmanaged or force mode enabled, automatically showing dialog")
if err := mdmMigrator.ShowInterval(); err != nil {
log.Error().Err(err).Msg("showing MDM migration dialog at interval")
}
}
}
} else {

View File

@ -11,6 +11,7 @@ import (
"os/exec"
"strings"
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/apple/mobileconfig"
)
@ -60,9 +61,30 @@ var execScript = func(script string) (*bytes.Buffer, error) {
// enrollment information and reports if the hostname of the MDM server
// supervising the device matches the hostname of the provided URL.
func IsEnrolledIntoMatchingURL(serverURL string) (bool, error) {
enrolled, currentURL, err := IsEnrolledInMDM()
if err != nil {
return false, fmt.Errorf("getting enrollment info: %w", err)
}
if !enrolled {
return false, nil
}
matches, err := fleethttp.HostnamesMatch(serverURL, currentURL)
if err != nil {
return false, fmt.Errorf("comparing URLs: %w", err)
}
return matches, nil
}
// IsEnrolledInMDM runs the `profiles` command to get the current MDM
// enrollment information and reports if the host is enrolled, and the URL of
// the MDM server (if enrolled)
func IsEnrolledInMDM() (bool, string, error) {
out, err := getMDMInfoFromProfilesCmd()
if err != nil {
return false, fmt.Errorf("calling /usr/bin/profiles: %w", err)
return false, "", fmt.Errorf("calling /usr/bin/profiles: %w", err)
}
// The output of the command is in the form:
@ -80,25 +102,17 @@ func IsEnrolledIntoMatchingURL(serverURL string) (bool, error) {
// 2. The last row matches our server URL
lines := bytes.Split(bytes.TrimSpace(out), []byte("\n"))
if len(lines) < 3 {
return false, nil
return false, "", nil
}
parts := bytes.SplitN(lines[2], []byte(":"), 2)
if len(parts) < 2 {
return false, fmt.Errorf("splitting profiles output to get MDM server URL: %w", err)
return false, "", fmt.Errorf("splitting profiles output to get MDM server URL: %w", err)
}
u, err := url.Parse(string(bytes.TrimSpace(parts[1])))
if err != nil {
return false, fmt.Errorf("parsing URL from profiles command: %w", err)
}
enrollmentURL := string(bytes.TrimSpace(parts[1]))
fu, err := url.Parse(serverURL)
if err != nil {
return false, fmt.Errorf("parsing provided Fleet URL: %w", err)
}
return u.Hostname() == fu.Hostname(), nil
return true, enrollmentURL, nil
}
// getMDMInfoFromProfilesCmd is declared as a variable so it can be overwritten by tests.

View File

@ -138,6 +138,79 @@ MDM server: https://valid.com/mdm/apple/mdm
}
}
func TestIsEnrolledInMDM(t *testing.T) {
cases := []struct {
cmdOut *string
cmdErr error
wantEnrolled bool
wantURL string
wantErr bool
}{
{nil, errors.New("test error"), false, "", true},
{ptr.String(""), nil, false, "", false},
{ptr.String(`
Enrolled via DEP: No
MDM enrollment: No
`), nil, false, "", false},
{
ptr.String(`
Enrolled via DEP: Yes
MDM enrollment: Yes
MDM server: https://test.example.com
`),
nil,
true,
"https://test.example.com",
false,
},
{
ptr.String(`
Enrolled via DEP: Yes
MDM enrollment: Yes
MDM server / https://test.example.com
`),
nil,
true,
"",
false,
},
{
ptr.String(`
Enrolled via DEP: Yes
MDM enrollment: Yes
MDM server: https://valid.com/mdm/apple/mdm
`),
nil,
true,
"https://valid.com/mdm/apple/mdm",
false,
},
}
origCmd := getMDMInfoFromProfilesCmd
t.Cleanup(func() { getMDMInfoFromProfilesCmd = origCmd })
for _, c := range cases {
getMDMInfoFromProfilesCmd = func() ([]byte, error) {
if c.cmdOut == nil {
return nil, c.cmdErr
}
var buf bytes.Buffer
buf.WriteString(*c.cmdOut)
return []byte(*c.cmdOut), nil
}
enrolled, url, err := IsEnrolledInMDM()
if c.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, c.wantEnrolled, enrolled)
require.Equal(t, c.wantURL, url)
}
}
func TestCheckAssignedEnrollmentProfile(t *testing.T) {
fleetURL := "https://valid.com"
cases := []struct {

View File

@ -8,6 +8,10 @@ func GetFleetdConfig() (*fleet.MDMAppleFleetdConfig, error) {
return nil, ErrNotImplemented
}
func IsEnrolledInMDM() (bool, string, error) {
return false, "", ErrNotImplemented
}
func IsEnrolledIntoMatchingURL(u string) (bool, error) {
return false, ErrNotImplemented
}

View File

@ -20,6 +20,13 @@ func TestIsEnrolledIntoMatchingURL(t *testing.T) {
require.False(t, enrolled)
}
func TestIsEnrolledInMDM(t *testing.T) {
enrolled, serverURL, err := IsEnrolledInMDM()
require.ErrorIs(t, ErrNotImplemented, err)
require.False(t, enrolled)
require.Empty(t, serverURL)
}
func TestCheckAssignedEnrollmentProfile(t *testing.T) {
err := CheckAssignedEnrollmentProfile("https://test.example.com")
require.ErrorIs(t, ErrNotImplemented, err)

View File

@ -12,7 +12,6 @@ import (
"text/template"
"time"
"github.com/fleetdm/fleet/v4/orbit/pkg/profiles"
"github.com/rs/zerolog/log"
)
@ -47,12 +46,11 @@ Please contact your IT admin [here]({{ .ContactURL }}).
// swiftDialog.
type baseDialog struct {
path string
fleetURL string
interruptCh chan struct{}
}
func newBaseDialog(path, fleetURL string) *baseDialog {
return &baseDialog{path: path, fleetURL: fleetURL, interruptCh: make(chan struct{})}
func newBaseDialog(path string) *baseDialog {
return &baseDialog{path: path, interruptCh: make(chan struct{})}
}
func (b *baseDialog) CanRun() bool {
@ -61,17 +59,7 @@ func (b *baseDialog) CanRun() bool {
return false
}
// we perform this check locally on the client too to avoid showing the
// dialog if the client has already migrated but the Fleet server
// doesn't know about this state yet.
enrolled, err := profiles.IsEnrolledIntoMatchingURL(b.fleetURL)
if err != nil {
log.Error().Err(err).Msg("fetching enrollment status to show swiftDialog")
return false
}
// only run the dialog if the host is not enrolled into Fleet
return !enrolled
return true
}
// Exit sends the interrupt signal to try and stop the current swiftDialog
@ -139,10 +127,10 @@ func (b *baseDialog) render(flags ...string) (chan swiftDialogExitCode, chan err
return exitCodeCh, errCh
}
func NewMDMMigrator(path, fleetURL string, frequency time.Duration, handler MDMMigratorHandler) MDMMigrator {
func NewMDMMigrator(path string, frequency time.Duration, handler MDMMigratorHandler) MDMMigrator {
return &swiftDialogMDMMigrator{
handler: handler,
baseDialog: newBaseDialog(path, fleetURL),
baseDialog: newBaseDialog(path),
frequency: frequency,
}
}

View File

@ -4,7 +4,7 @@ package useraction
import "time"
func NewMDMMigrator(path, fleetURL string, frequency time.Duration, handler MDMMigratorHandler) MDMMigrator {
func NewMDMMigrator(path string, frequency time.Duration, handler MDMMigratorHandler) MDMMigrator {
return &NoopMDMMigrator{}
}

View File

@ -5,7 +5,9 @@ package fleethttp
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"net/url"
"os"
"time"
@ -129,3 +131,19 @@ func NewGithubClient() *http.Client {
}
return NewClient()
}
// HostnamesMatch is an utility function to parse two strings as
// URLs and find if their hostnames match.
func HostnamesMatch(a, b string) (bool, error) {
ap, err := url.Parse(a)
if err != nil {
return false, fmt.Errorf("parsing URL %s: %w", a, err)
}
bp, err := url.Parse(b)
if err != nil {
return false, fmt.Errorf("parsing URL %s: %w", b, err)
}
return ap.Hostname() == bp.Hostname(), nil
}

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClient(t *testing.T) {
@ -69,3 +70,56 @@ func TestTransport(t *testing.T) {
})
}
}
func TestHostnamesMatch(t *testing.T) {
tests := []struct {
name string
inputA string
inputB string
expectedMatch bool
expectError bool
}{
{
name: "ValidHostnamesMatch",
inputA: "https://www.example.com/path",
inputB: "http://www.example.com:80",
expectedMatch: true,
expectError: false,
},
{
name: "ValidHostnamesDoNotMatch",
inputA: "https://www.example.com",
inputB: "https://sub.example.com",
expectedMatch: false,
expectError: false,
},
{
name: "InvalidURLA",
inputA: "ht tp://foo.com",
inputB: "https://www.example.com",
expectedMatch: false,
expectError: true,
},
{
name: "InvalidURLB",
inputA: "https://www.example.com",
inputB: "ht tp://foo.com",
expectedMatch: false,
expectError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
matched, err := HostnamesMatch(test.inputA, test.inputB)
if test.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expectedMatch, matched)
}
})
}
}

View File

@ -774,6 +774,16 @@ func (ds *Datastore) IngestMDMAppleDevicesFromDEPSync(ctx context.Context, devic
return createdCount, teamID, err
}
func (ds *Datastore) UpsertMDMAppleHostDEPAssignments(ctx context.Context, hosts []fleet.Host) error {
return ds.withTx(ctx, func(tx sqlx.ExtContext) error {
if err := upsertHostDEPAssignmentsDB(ctx, tx, hosts); err != nil {
return ctxerr.Wrap(ctx, err, "upsert host DEP assignments")
}
return nil
})
}
func upsertHostDEPAssignmentsDB(ctx context.Context, tx sqlx.ExtContext, hosts []fleet.Host) error {
if len(hosts) == 0 {
return nil

View File

@ -4172,7 +4172,7 @@ func (ds *Datastore) GetMatchingHostSerials(ctx context.Context, serials []strin
for _, serial := range serials {
args = append(args, serial)
}
stmt, args, err := sqlx.In("SELECT hardware_serial, team_id FROM hosts WHERE hardware_serial IN (?)", args)
stmt, args, err := sqlx.In("SELECT id, hardware_serial, team_id FROM hosts WHERE hardware_serial IN (?)", args)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building IN statement for matching hosts")
}

View File

@ -7154,6 +7154,7 @@ func testGetMatchingHostSerials(t *testing.T, ds *Datastore) {
PrimaryMac: "30-65-EC-6F-C4-58",
HardwareSerial: serial,
TeamID: tmID,
ID: uint(i),
})
require.NoError(t, err)
}
@ -7170,7 +7171,7 @@ func testGetMatchingHostSerials(t *testing.T, ds *Datastore) {
"partial matches",
[]string{"foo", "rab"},
map[string]*fleet.Host{
"foo": {HardwareSerial: "foo", TeamID: nil},
"foo": {HardwareSerial: "foo", TeamID: nil, ID: 1},
},
"",
},
@ -7178,9 +7179,9 @@ func testGetMatchingHostSerials(t *testing.T, ds *Datastore) {
"all matching",
[]string{"foo", "bar", "baz"},
map[string]*fleet.Host{
"foo": {HardwareSerial: "foo", TeamID: nil},
"bar": {HardwareSerial: "bar", TeamID: &team.ID},
"baz": {HardwareSerial: "baz", TeamID: nil},
"foo": {HardwareSerial: "foo", TeamID: nil, ID: 1},
"bar": {HardwareSerial: "bar", TeamID: &team.ID, ID: 2},
"baz": {HardwareSerial: "baz", TeamID: nil, ID: 3},
},
"",
},

View File

@ -840,6 +840,10 @@ type Datastore interface {
// MDMAppleListDevices lists all the MDM enrolled devices.
MDMAppleListDevices(ctx context.Context) ([]MDMAppleDevice, error)
// UpsertMDMAppleHostDEPAssignments ensures there's an entry in
// `host_dep_assignments` for all the provided hosts.
UpsertMDMAppleHostDEPAssignments(ctx context.Context, hosts []Host) error
// IngestMDMAppleDevicesFromDEPSync creates new Fleet host records for MDM-enrolled devices that are
// not already enrolled in Fleet. It returns the number of hosts created, the team id that they
// joined (nil for no team), and an error.

View File

@ -472,11 +472,13 @@ func (d *DEPService) processDeviceResponse(ctx context.Context, depClient *godep
if len(existingSerials) > 0 {
level.Info(kitlog.With(d.logger)).Log("msg", "gathering existing serials to assign devices", "len", len(existingSerials))
serialsByTeam := map[*uint][]string{}
hosts := []fleet.Host{}
for _, host := range existingSerials {
if serialsByTeam[host.TeamID] == nil {
serialsByTeam[host.TeamID] = []string{}
}
serialsByTeam[host.TeamID] = append(serialsByTeam[host.TeamID], host.HardwareSerial)
hosts = append(hosts, *host)
}
for team, serials := range serialsByTeam {
profUUID, err := d.getProfileUUIDForTeam(ctx, team)
@ -489,6 +491,11 @@ func (d *DEPService) processDeviceResponse(ctx context.Context, depClient *godep
profileToSerials[profUUID] = append(profileToSerials[profUUID], serials...)
}
if err := d.ds.UpsertMDMAppleHostDEPAssignments(ctx, hosts); err != nil {
return ctxerr.Wrap(ctx, err, "upserting dep assignment for existing device")
}
} else {
level.Info(kitlog.With(d.logger)).Log("msg", "no existing devices to assign DEP profiles")
}

View File

@ -568,6 +568,8 @@ type BatchSetMDMAppleProfilesFunc func(ctx context.Context, tmID *uint, profiles
type MDMAppleListDevicesFunc func(ctx context.Context) ([]fleet.MDMAppleDevice, error)
type UpsertMDMAppleHostDEPAssignmentsFunc func(ctx context.Context, hosts []fleet.Host) error
type IngestMDMAppleDevicesFromDEPSyncFunc func(ctx context.Context, devices []godep.Device) (int64, *uint, error)
type IngestMDMAppleDeviceFromCheckinFunc func(ctx context.Context, mdmHost fleet.MDMAppleHostDetails) error
@ -1486,6 +1488,9 @@ type DataStore struct {
MDMAppleListDevicesFunc MDMAppleListDevicesFunc
MDMAppleListDevicesFuncInvoked bool
UpsertMDMAppleHostDEPAssignmentsFunc UpsertMDMAppleHostDEPAssignmentsFunc
UpsertMDMAppleHostDEPAssignmentsFuncInvoked bool
IngestMDMAppleDevicesFromDEPSyncFunc IngestMDMAppleDevicesFromDEPSyncFunc
IngestMDMAppleDevicesFromDEPSyncFuncInvoked bool
@ -3552,6 +3557,13 @@ func (s *DataStore) MDMAppleListDevices(ctx context.Context) ([]fleet.MDMAppleDe
return s.MDMAppleListDevicesFunc(ctx)
}
func (s *DataStore) UpsertMDMAppleHostDEPAssignments(ctx context.Context, hosts []fleet.Host) error {
s.mu.Lock()
s.UpsertMDMAppleHostDEPAssignmentsFuncInvoked = true
s.mu.Unlock()
return s.UpsertMDMAppleHostDEPAssignmentsFunc(ctx, hosts)
}
func (s *DataStore) IngestMDMAppleDevicesFromDEPSync(ctx context.Context, devices []godep.Device) (int64, *uint, error) {
s.mu.Lock()
s.IngestMDMAppleDevicesFromDEPSyncFuncInvoked = true