Add coverage to uncovered mysql code (#1855)

* Add coverage to uncovered mysql code

* Add deleted method and update mock

* Fix test
This commit is contained in:
Tomas Touceda 2021-09-07 13:48:04 -03:00 committed by GitHub
parent a6acb1cd1a
commit 7c34956d31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 536 additions and 248 deletions

View File

@ -150,3 +150,26 @@ func TestCleanupDistributedQueryCampaigns(t *testing.T) {
}
}
func TestSaveDistributedQueryCampaign(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
user := test.NewUser(t, ds, t.Name(), t.Name()+"zwass@fleet.co", true)
mockClock := clock.NewMockClock()
query := test.NewQuery(t, ds, t.Name()+"test", "select * from time", user.ID, false)
c1 := test.NewCampaign(t, ds, query.ID, fleet.QueryWaiting, mockClock.Now())
gotC, err := ds.DistributedQueryCampaign(c1.ID)
require.NoError(t, err)
require.Equal(t, fleet.QueryWaiting, gotC.Status)
c1.Status = fleet.QueryComplete
require.NoError(t, ds.SaveDistributedQueryCampaign(c1))
gotC, err = ds.DistributedQueryCampaign(c1.ID)
require.NoError(t, err)
require.Equal(t, fleet.QueryComplete, gotC.Status)
}

View File

@ -0,0 +1,74 @@
package mysql
import (
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDeleteEntity(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
host, err := ds.NewHost(&fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: t.Name() + "1",
UUID: t.Name() + "1",
OsqueryHostID: t.Name(),
Hostname: t.Name() + "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
})
require.NoError(t, err)
require.NotNil(t, host)
require.NoError(t, ds.deleteEntity("hosts", host.ID))
host, err = ds.Host(host.ID)
require.Error(t, err)
assert.Nil(t, host)
}
func TestDeleteEntityByName(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
query1 := test.NewQuery(t, ds, t.Name()+"time", "select * from time", 0, true)
require.NoError(t, ds.deleteEntityByName("queries", query1.Name))
gotQ, err := ds.Query(query1.ID)
require.Error(t, err)
assert.Nil(t, gotQ)
}
func TestDeleteEntities(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
query1 := test.NewQuery(t, ds, t.Name()+"time1", "select * from time", 0, true)
query2 := test.NewQuery(t, ds, t.Name()+"time2", "select * from time", 0, true)
query3 := test.NewQuery(t, ds, t.Name()+"time3", "select * from time", 0, true)
count, err := ds.deleteEntities("queries", []uint{query1.ID, query2.ID})
require.NoError(t, err)
assert.Equal(t, uint(2), count)
gotQ, err := ds.Query(query1.ID)
require.Error(t, err)
assert.Nil(t, gotQ)
gotQ, err = ds.Query(query2.ID)
require.Error(t, err)
assert.Nil(t, gotQ)
gotQ, err = ds.Query(query3.ID)
require.NoError(t, err)
assert.Equal(t, query3.ID, gotQ.ID)
}

View File

@ -1282,7 +1282,7 @@ func TestAddHostsToTeam(t *testing.T) {
}
}
func TestSaveUsers(t *testing.T) {
func TestSaveHostUsers(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()

View File

@ -41,7 +41,6 @@ func TestCreateInvite(t *testing.T) {
}
func setupTestInvites(t *testing.T, ds fleet.Datastore) {
var err error
admin := &fleet.Invite{
Email: "admin@foo.com",
Name: "Xadmin",
@ -49,8 +48,8 @@ func setupTestInvites(t *testing.T, ds fleet.Datastore) {
GlobalRole: null.StringFrom("admin"),
}
admin, err = ds.NewInvite(admin)
require.Nil(t, err)
admin, err := ds.NewInvite(admin)
require.NoError(t, err)
for user := 0; user < 23; user++ {
i := fleet.Invite{
@ -62,7 +61,7 @@ func setupTestInvites(t *testing.T, ds fleet.Datastore) {
}
_, err := ds.NewInvite(&i)
assert.Nil(t, err, "Failure creating user", user)
require.NoError(t, err, "Failure creating user", user)
}
}
@ -184,3 +183,22 @@ func TestInviteByEmail(t *testing.T) {
})
}
}
func TestInvite(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
admin := &fleet.Invite{
Email: "admin@foo.com",
Name: "Xadmin",
Token: "admin",
GlobalRole: null.StringFrom("admin"),
}
admin, err := ds.NewInvite(admin)
require.NoError(t, err)
gotI, err := ds.Invite(admin.ID)
require.NoError(t, err)
assert.Equal(t, admin.ID, gotI.ID)
}

View File

@ -2,18 +2,14 @@ package mysql
import (
"bytes"
"database/sql"
"fmt"
"os/exec"
"testing"
"time"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/go-kit/kit/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -60,14 +56,6 @@ func TestMigrations(t *testing.T) {
}
func createMySQLDSForMigrationTests(t *testing.T, dbName string) *Datastore {
db, err := sql.Open(
"mysql",
fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testUsername, testPassword, testAddress),
)
require.NoError(t, err)
_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
require.NoError(t, err)
// Create a datastore client in order to run migrations as usual
config := config.MysqlConfig{
Username: testUsername,
@ -75,7 +63,7 @@ func createMySQLDSForMigrationTests(t *testing.T, dbName string) *Datastore {
Address: testAddress,
Database: dbName,
}
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
ds, err := newDSWithConfig(t, dbName, config)
require.NoError(t, err)
return ds
}

View File

@ -262,14 +262,6 @@ func checkConfig(conf *config.MysqlConfig) error {
return nil
}
func (d *Datastore) Begin() (fleet.Transaction, error) {
return d.writer.Beginx()
}
func (d *Datastore) Name() string {
return "mysql"
}
func (d *Datastore) MigrateTables() error {
return tables.MigrationClient.Up(d.writer.DB, "")
}
@ -316,45 +308,6 @@ func (d *Datastore) MigrationStatus() (fleet.MigrationStatus, error) {
}
}
// Drop removes database
func (d *Datastore) Drop() error {
tables := []struct {
Name string `db:"TABLE_NAME"`
}{}
sql := `
SELECT TABLE_NAME
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = ?;
`
if err := d.writer.Select(&tables, sql, d.config.Database); err != nil {
return err
}
tx, err := d.writer.Begin()
if err != nil {
return err
}
_, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 0")
if err != nil {
return tx.Rollback()
}
for _, table := range tables {
_, err = tx.Exec(fmt.Sprintf("DROP TABLE %s;", table.Name))
if err != nil {
return tx.Rollback()
}
}
_, err = tx.Exec("SET FOREIGN_KEY_CHECKS = 1")
if err != nil {
return tx.Rollback()
}
return tx.Commit()
}
// HealthCheck returns an error if the MySQL backend is not healthy.
func (d *Datastore) HealthCheck() error {
if _, err := d.writer.Exec("select 1"); err != nil {

View File

@ -1,12 +1,23 @@
package mysql
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/VividCortex/mysqlerr"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/kit/log"
@ -511,3 +522,226 @@ func TestWhereOmitIDs(t *testing.T) {
})
}
}
func TestWithRetryTxWithRollback(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withRetryTxx(func(tx *sqlx.Tx) error {
_, err := tx.Exec("SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxWillRollbackWhenPanic(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
defer func() { recover() }()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withRetryTxx(func(tx *sqlx.Tx) error {
panic("ROLLBACK")
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithTxWithRollback(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withTx(func(tx *sqlx.Tx) error {
_, err := tx.Exec("SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithTxWillRollbackWhenPanic(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
defer func() { recover() }()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withTx(func(tx *sqlx.Tx) error {
panic("ROLLBACK")
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestNewReadsPasswordFromDisk(t *testing.T) {
passwordFile, err := os.CreateTemp(t.TempDir(), "*.passwordtest")
require.NoError(t, err)
_, err = passwordFile.WriteString(testPassword)
require.NoError(t, err)
passwordPath := passwordFile.Name()
require.NoError(t, passwordFile.Close())
dbName := t.Name()
// Create a datastore client in order to run migrations as usual
mysqlConfig := config.MysqlConfig{
Username: testUsername,
Password: "",
PasswordPath: passwordPath,
Address: testAddress,
Database: dbName,
}
ds, err := newDSWithConfig(t, dbName, mysqlConfig)
require.NoError(t, err)
defer ds.Close()
require.NoError(t, ds.HealthCheck())
}
func newDSWithConfig(t *testing.T, dbName string, config config.MysqlConfig) (*Datastore, error) {
db, err := sql.Open(
"mysql",
fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testUsername, testPassword, testAddress),
)
require.NoError(t, err)
_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
require.NoError(t, err)
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
return ds, err
}
func generateTestCert(t *testing.T) (string, string) {
privateKeyCA, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"aa"},
},
NotBefore: time.Now().Add(-1 * time.Duration(24) * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKeyCA.PublicKey, privateKeyCA)
require.NoError(t, err)
publicPem, err := os.CreateTemp(t.TempDir(), "*-ca.pem")
require.NoError(t, err)
require.NoError(t, pem.Encode(publicPem, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
require.NoError(t, publicPem.Close())
keyPem, err := os.CreateTemp(t.TempDir(), "*-key.pem")
require.NoError(t, err)
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKeyCA)
require.NoError(t, pem.Encode(keyPem, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}))
require.NoError(t, keyPem.Close())
return publicPem.Name(), keyPem.Name()
}
func TestNewUsesRegisterTLS(t *testing.T) {
dbName := t.Name()
ca, _ := generateTestCert(t)
cert, key := generateTestCert(t)
mysqlConfig := config.MysqlConfig{
Username: testUsername,
Password: testPassword,
Address: testAddress,
Database: dbName,
TLSCA: ca,
TLSCert: cert,
TLSKey: key,
}
// This fails because the certificate mysql is using is different than the one generated here
_, err := newDSWithConfig(t, dbName, mysqlConfig)
require.Error(t, err)
require.Equal(t, "x509: certificate is not valid for any names, but wanted to match localhost", err.Error())
}
func TestWhereFilterTeas(t *testing.T) {
t.Parallel()
testCases := []struct {
filter fleet.TeamFilter
expected string
}{
// No teams or global role
{
filter: fleet.TeamFilter{User: nil},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: false,
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}},
expected: "t.id IN (1)",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}},
expected: "t.id IN (1)",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
IncludeObserver: true,
},
expected: "t.id IN (1)",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
t.Parallel()
ds := &Datastore{logger: log.NewNopLogger()}
sql := ds.whereFilterTeams(tt.filter, "t")
assert.Equal(t, tt.expected, sql)
})
}
}

View File

@ -22,38 +22,6 @@ func (d *Datastore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
}
func (d *Datastore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) error {
sqlStatement := `
UPDATE password_reset_requests SET
expires_at = ?,
user_id = ?,
token = ?
WHERE id = ?
`
result, err := d.writer.Exec(sqlStatement, req.ExpiresAt, req.UserID, req.Token, req.ID)
if err != nil {
return errors.Wrap(err, "updating password reset requests")
}
rows, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "rows affected updating password reset requests")
}
if rows == 0 {
return notFound("PasswordResetRequest").WithID(req.ID)
}
return nil
}
func (d *Datastore) DeletePasswordResetRequest(req *fleet.PasswordResetRequest) error {
err := d.deleteEntity("password_reset_requests", req.ID)
if err != nil {
return errors.Wrap(err, "deleting from password reset request")
}
return nil
}
func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
sqlStatement := `
DELETE FROM password_reset_requests WHERE user_id = ?
@ -65,42 +33,11 @@ func (d *Datastore) DeletePasswordResetRequestsForUser(userID uint) error {
return nil
}
func (d *Datastore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest, error) {
sqlStatement := `
SELECT * FROM password_reset_requests
WHERE id = ? LIMIT 1
`
passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.reader.Get(&passwordResetRequest, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "selecting password reset by id")
}
return passwordResetRequest, nil
}
func (d *Datastore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordResetRequest, error) {
sqlStatement := `
SELECT * FROM password_reset_requests
WHERE user_id = ?
`
passwordResetRequests := []*fleet.PasswordResetRequest{}
err := d.reader.Select(&passwordResetRequests, sqlStatement, id)
if err != nil {
return nil, errors.Wrap(err, "finding password resets by user id")
}
return passwordResetRequests, nil
}
func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordResetRequest, error) {
sqlStatement := `
SELECT * FROM password_reset_requests
WHERE token = ? LIMIT 1
`
SELECT * FROM password_reset_requests
WHERE token = ? LIMIT 1
`
passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.reader.Get(passwordResetRequest, sqlStatement, token)
if err != nil {
@ -110,18 +47,3 @@ func (d *Datastore) FindPassswordResetByToken(token string) (*fleet.PasswordRese
return passwordResetRequest, nil
}
func (d *Datastore) FindPassswordResetByTokenAndUserID(token string, id uint) (*fleet.PasswordResetRequest, error) {
sqlStatement := `
SELECT * FROM password_reset_requests
WHERE user_id = ? AND token = ?
LIMIT 1
`
passwordResetRequest := &fleet.PasswordResetRequest{}
err := d.reader.Get(passwordResetRequest, sqlStatement, id, token)
if err != nil {
return nil, errors.Wrap(err, "selecting password reset by token and user id")
}
return passwordResetRequest, nil
}

View File

@ -46,7 +46,6 @@ func (d *Datastore) ListSessionsForUser(id uint) ([]*fleet.Session, error) {
}
return sessions, nil
}
func (d *Datastore) NewSession(session *fleet.Session) (*fleet.Session, error) {

View File

@ -0,0 +1,54 @@
package mysql
import (
"testing"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSessionGetters(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
user, err := ds.NewUser(&fleet.User{
Password: []byte("supersecret"),
Email: "other@bobcom",
GlobalRole: ptr.String(fleet.RoleObserver),
})
require.NoError(t, err)
session, err := ds.NewSession(&fleet.Session{UserID: user.ID, Key: "somekey"})
require.NoError(t, err)
require.NotZero(t, session.ID)
gotByID, err := ds.SessionByID(session.ID)
require.NoError(t, err)
assert.Equal(t, session.Key, gotByID.Key)
gotByKey, err := ds.SessionByKey(session.Key)
require.NoError(t, err)
assert.Equal(t, session.ID, gotByKey.ID)
newSession, err := ds.NewSession(&fleet.Session{UserID: user.ID, Key: "somekey2"})
require.NoError(t, err)
sessions, err := ds.ListSessionsForUser(user.ID)
require.NoError(t, err)
require.Len(t, sessions, 2)
require.NoError(t, ds.DestroySession(session))
prevAccessedAt := newSession.AccessedAt
require.NoError(t, ds.MarkSessionAccessed(newSession))
sessions, err = ds.ListSessionsForUser(user.ID)
require.NoError(t, err)
require.Len(t, sessions, 1)
require.NotEqual(t, prevAccessedAt, sessions[0].AccessedAt)
require.NoError(t, ds.DestroyAllSessionsForUser(user.ID))
}

View File

@ -121,6 +121,7 @@ func TestSoftwareCPE(t *testing.T) {
require.NoError(t, err)
iterator, err := ds.AllSoftwareWithoutCPEIterator()
defer iterator.Close()
require.NoError(t, err)
loops := 0
@ -143,11 +144,13 @@ func TestSoftwareCPE(t *testing.T) {
loops++
}
assert.Equal(t, len(host1.Software), loops)
require.NoError(t, iterator.Close())
err = ds.AddCPEForSoftware(fleet.Software{ID: id}, "some:cpe")
require.NoError(t, err)
iterator, err = ds.AllSoftwareWithoutCPEIterator()
defer iterator.Close()
require.NoError(t, err)
loops = 0
@ -169,6 +172,7 @@ func TestSoftwareCPE(t *testing.T) {
loops++
}
assert.Equal(t, len(host1.Software)-1, loops)
require.NoError(t, iterator.Close())
}
func TestInsertCVEs(t *testing.T) {
@ -291,6 +295,23 @@ func TestAllCPEs(t *testing.T) {
assert.ElementsMatch(t, cpes, []string{"somecpe", "someothercpewithoutvulns"})
}
func TestNothingChanged(t *testing.T) {
assert.False(t, nothingChanged(nil, []fleet.Software{{}}))
assert.True(t, nothingChanged(nil, nil))
assert.True(t, nothingChanged(
[]fleet.Software{{Name: "A", Version: "1.0", Source: "ASD"}},
[]fleet.Software{{Name: "A", Version: "1.0", Source: "ASD"}},
))
assert.False(t, nothingChanged(
[]fleet.Software{{Name: "A", Version: "1.1", Source: "ASD"}},
[]fleet.Software{{Name: "A", Version: "1.0", Source: "ASD"}},
))
assert.False(t, nothingChanged(
[]fleet.Software{{Name: "A", Version: "1.0", Source: "ASD"}},
[]fleet.Software{{Name: "A", Version: "1.0", Source: "ASD"}, {Name: "B", Version: "1.0", Source: "ASD"}},
))
}
func TestLoadSupportsTonsOfCVEs(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()

View File

@ -300,3 +300,40 @@ func TestHostIDsInTargets(t *testing.T) {
require.Nil(t, err)
assert.Equal(t, []uint{1, 3, 4, 5, 6}, ids)
}
func TestHostIDsInTargetsTeam(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
user := &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}
filter := fleet.TeamFilter{User: user}
mockClock := clock.NewMockClock()
hostCount := 0
initHost := func(seenTime time.Time, distributedInterval uint, configTLSRefresh uint, teamID *uint) *fleet.Host {
hostCount += 1
h, err := ds.NewHost(&fleet.Host{
OsqueryHostID: strconv.Itoa(hostCount),
DetailUpdatedAt: mockClock.Now(),
LabelUpdatedAt: mockClock.Now(),
SeenTime: mockClock.Now(),
NodeKey: strconv.Itoa(hostCount),
DistributedInterval: distributedInterval,
ConfigTLSRefresh: configTLSRefresh,
TeamID: teamID,
})
require.Nil(t, err)
require.Nil(t, ds.MarkHostSeen(h, seenTime))
return h
}
team1, err := ds.NewTeam(&fleet.Team{Name: t.Name() + "team1"})
require.NoError(t, err)
h1 := initHost(mockClock.Now().Add(-1*time.Second), 10, 60, &team1.ID)
targets, err := ds.HostIDsInTargets(filter, fleet.HostTargets{TeamIDs: []uint{team1.ID}})
require.NoError(t, err)
assert.Equal(t, []uint{h1.ID}, targets)
}

View File

@ -205,3 +205,26 @@ func TestTeamSearchTeams(t *testing.T) {
require.NoError(t, err)
assert.Len(t, teams, 0)
}
func TestTeamEnrollSecrets(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
secrets := []*fleet.EnrollSecret{{Secret: "secret1"}, {Secret: "secret2"}}
team1, err := ds.NewTeam(&fleet.Team{
Name: "team1",
Secrets: secrets,
})
require.NoError(t, err)
enrollSecrets, err := ds.TeamEnrollSecrets(team1.ID)
require.NoError(t, err)
var justSecrets []*fleet.EnrollSecret
for _, secret := range enrollSecrets {
require.NotNil(t, secret.TeamID)
assert.Equal(t, team1.ID, *secret.TeamID)
justSecrets = append(justSecrets, &fleet.EnrollSecret{Secret: secret.Secret})
}
test.ElementsMatchSkipTimestampsID(t, secrets, justSecrets)
}

View File

@ -2,9 +2,12 @@ package mysql
import (
"fmt"
"github.com/fleetdm/fleet/v4/server"
"strings"
"testing"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/test"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/stretchr/testify/assert"
@ -298,3 +301,30 @@ func TestUserCreateWithTeams(t *testing.T) {
assert.Equal(t, uint(9), user.Teams[2].ID)
assert.Equal(t, "maintainer", user.Teams[2].Role)
}
func TestSaveUsers(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
u1 := test.NewUser(t, ds, t.Name()+"Admin1", t.Name()+"admin1@fleet.co", true)
u2 := test.NewUser(t, ds, t.Name()+"Admin2", t.Name()+"admin2@fleet.co", true)
u3 := test.NewUser(t, ds, t.Name()+"Admin3", t.Name()+"admin3@fleet.co", true)
u1.Email += "m"
u2.Email += "m"
u3.Email += "m"
require.NoError(t, ds.SaveUsers([]*fleet.User{u1, u2, u3}))
gotU1, err := ds.UserByID(u1.ID)
require.NoError(t, err)
assert.True(t, strings.HasSuffix(gotU1.Email, "fleet.com"))
gotU2, err := ds.UserByID(u3.ID)
require.NoError(t, err)
assert.True(t, strings.HasSuffix(gotU2.Email, "fleet.com"))
gotU3, err := ds.UserByID(u3.ID)
require.NoError(t, err)
assert.True(t, strings.HasSuffix(gotU3.Email, "fleet.com"))
}

View File

@ -212,13 +212,8 @@ type Datastore interface {
// PasswordResetStore manages password resets in the Datastore
NewPasswordResetRequest(req *PasswordResetRequest) (*PasswordResetRequest, error)
SavePasswordResetRequest(req *PasswordResetRequest) error
DeletePasswordResetRequest(req *PasswordResetRequest) error
DeletePasswordResetRequestsForUser(userID uint) error
FindPassswordResetByID(id uint) (*PasswordResetRequest, error)
FindPassswordResetsByUserID(id uint) ([]*PasswordResetRequest, error)
FindPassswordResetByToken(token string) (*PasswordResetRequest, error)
FindPassswordResetByTokenAndUserID(token string, id uint) (*PasswordResetRequest, error)
///////////////////////////////////////////////////////////////////////////////
// SessionStore is the abstract interface that all session backends must conform to.
@ -345,15 +340,12 @@ type Datastore interface {
PolicyQueriesForHost(host *Host) (map[string]string, error)
Name() string
Drop() error
// MigrateTables creates and migrates the table schemas
MigrateTables() error
// MigrateData populates built-in data
MigrateData() error
// MigrationStatus returns nil if migrations are complete, and an error if migrations need to be run.
MigrationStatus() (MigrationStatus, error)
Begin() (Transaction, error)
}
type MigrationStatus int

View File

@ -168,20 +168,10 @@ type HostIDsInTargetsFunc func(filter fleet.TeamFilter, targets fleet.HostTarget
type NewPasswordResetRequestFunc func(req *fleet.PasswordResetRequest) (*fleet.PasswordResetRequest, error)
type SavePasswordResetRequestFunc func(req *fleet.PasswordResetRequest) error
type DeletePasswordResetRequestFunc func(req *fleet.PasswordResetRequest) error
type DeletePasswordResetRequestsForUserFunc func(userID uint) error
type FindPassswordResetByIDFunc func(id uint) (*fleet.PasswordResetRequest, error)
type FindPassswordResetsByUserIDFunc func(id uint) ([]*fleet.PasswordResetRequest, error)
type FindPassswordResetByTokenFunc func(token string) (*fleet.PasswordResetRequest, error)
type FindPassswordResetByTokenAndUserIDFunc func(token string, id uint) (*fleet.PasswordResetRequest, error)
type SessionByKeyFunc func(key string) (*fleet.Session, error)
type SessionByIDFunc func(id uint) (*fleet.Session, error)
@ -280,18 +270,12 @@ type DeleteGlobalPoliciesFunc func(ids []uint) ([]uint, error)
type PolicyQueriesForHostFunc func(host *fleet.Host) (map[string]string, error)
type NameFunc func() string
type DropFunc func() error
type MigrateTablesFunc func() error
type MigrateDataFunc func() error
type MigrationStatusFunc func() (fleet.MigrationStatus, error)
type BeginFunc func() (fleet.Transaction, error)
type DataStore struct {
NewCarveFunc NewCarveFunc
NewCarveFuncInvoked bool
@ -530,27 +514,12 @@ type DataStore struct {
NewPasswordResetRequestFunc NewPasswordResetRequestFunc
NewPasswordResetRequestFuncInvoked bool
SavePasswordResetRequestFunc SavePasswordResetRequestFunc
SavePasswordResetRequestFuncInvoked bool
DeletePasswordResetRequestFunc DeletePasswordResetRequestFunc
DeletePasswordResetRequestFuncInvoked bool
DeletePasswordResetRequestsForUserFunc DeletePasswordResetRequestsForUserFunc
DeletePasswordResetRequestsForUserFuncInvoked bool
FindPassswordResetByIDFunc FindPassswordResetByIDFunc
FindPassswordResetByIDFuncInvoked bool
FindPassswordResetsByUserIDFunc FindPassswordResetsByUserIDFunc
FindPassswordResetsByUserIDFuncInvoked bool
FindPassswordResetByTokenFunc FindPassswordResetByTokenFunc
FindPassswordResetByTokenFuncInvoked bool
FindPassswordResetByTokenAndUserIDFunc FindPassswordResetByTokenAndUserIDFunc
FindPassswordResetByTokenAndUserIDFuncInvoked bool
SessionByKeyFunc SessionByKeyFunc
SessionByKeyFuncInvoked bool
@ -698,12 +667,6 @@ type DataStore struct {
PolicyQueriesForHostFunc PolicyQueriesForHostFunc
PolicyQueriesForHostFuncInvoked bool
NameFunc NameFunc
NameFuncInvoked bool
DropFunc DropFunc
DropFuncInvoked bool
MigrateTablesFunc MigrateTablesFunc
MigrateTablesFuncInvoked bool
@ -712,9 +675,6 @@ type DataStore struct {
MigrationStatusFunc MigrationStatusFunc
MigrationStatusFuncInvoked bool
BeginFunc BeginFunc
BeginFuncInvoked bool
}
func (s *DataStore) NewCarve(metadata *fleet.CarveMetadata) (*fleet.CarveMetadata, error) {
@ -1112,41 +1072,16 @@ func (s *DataStore) NewPasswordResetRequest(req *fleet.PasswordResetRequest) (*f
return s.NewPasswordResetRequestFunc(req)
}
func (s *DataStore) SavePasswordResetRequest(req *fleet.PasswordResetRequest) error {
s.SavePasswordResetRequestFuncInvoked = true
return s.SavePasswordResetRequestFunc(req)
}
func (s *DataStore) DeletePasswordResetRequest(req *fleet.PasswordResetRequest) error {
s.DeletePasswordResetRequestFuncInvoked = true
return s.DeletePasswordResetRequestFunc(req)
}
func (s *DataStore) DeletePasswordResetRequestsForUser(userID uint) error {
s.DeletePasswordResetRequestsForUserFuncInvoked = true
return s.DeletePasswordResetRequestsForUserFunc(userID)
}
func (s *DataStore) FindPassswordResetByID(id uint) (*fleet.PasswordResetRequest, error) {
s.FindPassswordResetByIDFuncInvoked = true
return s.FindPassswordResetByIDFunc(id)
}
func (s *DataStore) FindPassswordResetsByUserID(id uint) ([]*fleet.PasswordResetRequest, error) {
s.FindPassswordResetsByUserIDFuncInvoked = true
return s.FindPassswordResetsByUserIDFunc(id)
}
func (s *DataStore) FindPassswordResetByToken(token string) (*fleet.PasswordResetRequest, error) {
s.FindPassswordResetByTokenFuncInvoked = true
return s.FindPassswordResetByTokenFunc(token)
}
func (s *DataStore) FindPassswordResetByTokenAndUserID(token string, id uint) (*fleet.PasswordResetRequest, error) {
s.FindPassswordResetByTokenAndUserIDFuncInvoked = true
return s.FindPassswordResetByTokenAndUserIDFunc(token, id)
}
func (s *DataStore) SessionByKey(key string) (*fleet.Session, error) {
s.SessionByKeyFuncInvoked = true
return s.SessionByKeyFunc(key)
@ -1392,16 +1327,6 @@ func (s *DataStore) PolicyQueriesForHost(host *fleet.Host) (map[string]string, e
return s.PolicyQueriesForHostFunc(host)
}
func (s *DataStore) Name() string {
s.NameFuncInvoked = true
return s.NameFunc()
}
func (s *DataStore) Drop() error {
s.DropFuncInvoked = true
return s.DropFunc()
}
func (s *DataStore) MigrateTables() error {
s.MigrateTablesFuncInvoked = true
return s.MigrateTablesFunc()
@ -1416,8 +1341,3 @@ func (s *DataStore) MigrationStatus() (fleet.MigrationStatus, error) {
s.MigrationStatusFuncInvoked = true
return s.MigrationStatusFunc()
}
func (s *DataStore) Begin() (fleet.Transaction, error) {
s.BeginFuncInvoked = true
return s.BeginFunc()
}

View File

@ -21,11 +21,11 @@ func NewQuery(t *testing.T, ds fleet.Datastore, name, q string, authorID uint, s
AuthorID: authorPtr,
Saved: saved,
})
require.Nil(t, err)
require.NoError(t, err)
// Loading gives us the timestamps
query, err = ds.Query(query.ID)
require.Nil(t, err)
require.NoError(t, err)
return query
}
@ -37,7 +37,7 @@ func NewPack(t *testing.T, ds fleet.Datastore, name string) *fleet.Pack {
// Loading gives us the timestamps
pack, ok, err := ds.PackByName(name)
require.True(t, ok)
require.Nil(t, err)
require.NoError(t, err)
return pack
}
@ -52,11 +52,11 @@ func NewCampaign(t *testing.T, ds fleet.Datastore, queryID uint, status fleet.Di
QueryID: queryID,
Status: status,
})
require.Nil(t, err)
require.NoError(t, err)
// Loading gives us the timestamps
campaign, err = ds.DistributedQueryCampaign(campaign.ID)
require.Nil(t, err)
require.NoError(t, err)
return campaign
}
@ -68,7 +68,7 @@ func AddHostToCampaign(t *testing.T, ds fleet.Datastore, campaignID, hostID uint
TargetID: hostID,
DistributedQueryCampaignID: campaignID,
})
require.Nil(t, err)
require.NoError(t, err)
}
func AddLabelToCampaign(t *testing.T, ds fleet.Datastore, campaignID, labelID uint) {
@ -78,7 +78,7 @@ func AddLabelToCampaign(t *testing.T, ds fleet.Datastore, campaignID, labelID ui
TargetID: labelID,
DistributedQueryCampaignID: campaignID,
})
require.Nil(t, err)
require.NoError(t, err)
}
func AddAllHostsLabel(t *testing.T, ds fleet.Datastore) {
@ -90,7 +90,7 @@ func AddAllHostsLabel(t *testing.T, ds fleet.Datastore) {
LabelMembershipType: fleet.LabelMembershipTypeManual,
},
)
require.Nil(t, err)
require.NoError(t, err)
}
func NewHost(t *testing.T, ds fleet.Datastore, name, ip, key, uuid string, now time.Time) *fleet.Host {
@ -105,9 +105,9 @@ func NewHost(t *testing.T, ds fleet.Datastore, name, ip, key, uuid string, now t
OsqueryHostID: osqueryHostID,
})
require.Nil(t, err)
require.NoError(t, err)
require.NotZero(t, h.ID)
require.Nil(t, ds.MarkHostSeen(h, now))
require.NoError(t, ds.MarkHostSeen(h, now))
return h
}
@ -125,7 +125,7 @@ func NewUser(t *testing.T, ds fleet.Datastore, name, email string, admin bool) *
GlobalRole: &role,
})
require.Nil(t, err)
require.NoError(t, err)
require.NotZero(t, u.ID)
return u
@ -140,7 +140,7 @@ func NewScheduledQuery(t *testing.T, ds fleet.Datastore, pid, qid, interval uint
Snapshot: &snapshot,
Removed: &removed,
})
require.Nil(t, err)
require.NoError(t, err)
require.NotZero(t, sq.ID)
return sq