diff --git a/server/datastore/mysql/apple_mdm.go b/server/datastore/mysql/apple_mdm.go index 10d1b27fa..8432c068b 100644 --- a/server/datastore/mysql/apple_mdm.go +++ b/server/datastore/mysql/apple_mdm.go @@ -3181,7 +3181,7 @@ WHERE h.uuid = ? return nil } -func (ds *Datastore) batchSetMDMAppleDeclarations(ctx context.Context, tx sqlx.ExtContext, tmID *uint, declarations []*fleet.MDMAppleDeclaration) ([]*fleet.MDMAppleDeclaration, error) { +func (ds *Datastore) batchSetMDMAppleDeclarations(ctx context.Context, tx sqlx.ExtContext, tmID *uint, incomingDeclarations []*fleet.MDMAppleDeclaration) ([]*fleet.MDMAppleDeclaration, error) { const insertStmt = ` INSERT INTO mdm_apple_declarations ( declaration_uuid, @@ -3227,15 +3227,13 @@ WHERE declTeamID = *tmID } - var incomingLabels []fleet.ConfigurationProfileLabel - // build a list of identifiers for the incoming declarations, will keep the // existing ones if there's a match and no change - incomingIdents := make([]string, len(declarations)) + incomingIdents := make([]string, len(incomingDeclarations)) // at the same time, index the incoming declarations keyed by identifier for ease // or processing - incomingDecls := make(map[string]*fleet.MDMAppleDeclaration, len(declarations)) - for i, p := range declarations { + incomingDecls := make(map[string]*fleet.MDMAppleDeclaration, len(incomingDeclarations)) + for i, p := range incomingDeclarations { incomingIdents[i] = p.Identifier incomingDecls[p.Identifier] = p } @@ -3293,10 +3291,10 @@ WHERE if err == nil { err = errors.New(ds.testBatchSetMDMAppleProfilesErr) } - return nil, ctxerr.Wrap(ctx, err, "delete obsolete profiles") + return nil, ctxerr.Wrap(ctx, err, "delete obsolete declarations") } - for _, d := range declarations { + for _, d := range incomingDeclarations { checksum := md5ChecksumScriptContent(string(d.RawJSON)) declUUID := fleet.MDMAppleDeclarationUUIDPrefix + uuid.NewString() if _, err := tx.ExecContext(ctx, insertStmt, @@ -3311,11 +3309,36 @@ WHERE } return nil, ctxerr.Wrapf(ctx, err, "insert new/edited declaration with identifier %q", d.Identifier) } + } - d.DeclarationUUID = declUUID - for _, l := range d.Labels { - l.ProfileUUID = declUUID - incomingLabels = append(incomingLabels, l) + incomingLabels := []fleet.ConfigurationProfileLabel{} + if len(incomingIdents) > 0 { + var newlyInsertedDecls []*fleet.MDMAppleDeclaration + // load current declarations (again) that match the incoming declarations by name to grab their uuids + // this is an easy way to grab the identifiers for both the existing declarations and the new ones we generated. + // + // TODO(roberto): if we're a bit careful, we can harvest this + // information without this extra request in the previous DB + // calls. Due to time constraints, I'm leaving that + // optimization for a later iteration. + stmt, args, err := sqlx.In(loadExistingDecls, declTeamID, incomingIdents) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "build query to load newly inserted declarations") + } + if err := sqlx.SelectContext(ctx, tx, &newlyInsertedDecls, stmt, args...); err != nil { + return nil, ctxerr.Wrap(ctx, err, "load newly inserted declarations") + } + + for _, newlyInsertedDecl := range newlyInsertedDecls { + incomingDecl, ok := incomingDecls[newlyInsertedDecl.Identifier] + if !ok { + return nil, ctxerr.Wrapf(ctx, err, "declaration %q is in the database but was not incoming", newlyInsertedDecl.Identifier) + } + + for _, label := range incomingDecl.Labels { + label.ProfileUUID = newlyInsertedDecl.DeclarationUUID + incomingLabels = append(incomingLabels, label) + } } } @@ -3323,10 +3346,10 @@ WHERE if err == nil { err = errors.New(ds.testBatchSetMDMAppleProfilesErr) } - return nil, ctxerr.Wrap(ctx, err, "inserting apple profile label associations") + return nil, ctxerr.Wrap(ctx, err, "inserting apple declaration label associations") } - return declarations, nil + return incomingDeclarations, nil } func (ds *Datastore) NewMDMAppleDeclaration(ctx context.Context, declaration *fleet.MDMAppleDeclaration) (*fleet.MDMAppleDeclaration, error) { diff --git a/server/service/client.go b/server/service/client.go index f26bfcc05..bc8c9f17c 100644 --- a/server/service/client.go +++ b/server/service/client.go @@ -294,8 +294,9 @@ func getProfilesContents(baseDir string, profiles []fleet.MDMProfileSpec) ([]fle } // by default, use the file name. macOS profiles use their PayloadDisplayName - name := strings.TrimSuffix(filepath.Base(filePath), filepath.Ext(filePath)) - if mdm.GetRawProfilePlatform(fileContents) == "darwin" { + ext := filepath.Ext(filePath) + name := strings.TrimSuffix(filepath.Base(filePath), ext) + if mdm.GetRawProfilePlatform(fileContents) == "darwin" && ext == ".mobileconfig" { mc, err := fleet.NewMDMAppleConfigProfile(fileContents, nil) if err != nil { return nil, fmt.Errorf("applying fleet config: %w", err)