fleet/server/datastore/mysql/mysql_test.go
Martin Angers 9082438580
Feature: Remote Lock for macOS, Windows and Linux (#16783)
Feature branch for the #9949  story.

---------

Co-authored-by: Jahziel Villasana-Espinoza <jahziel@fleetdm.com>
Co-authored-by: Roberto Dip <me@roperzh.com>
Co-authored-by: Gabriel Hernandez <ghernandez345@gmail.com>
Co-authored-by: Sarah Gillespie <sarah@fleetdm.com>
2024-02-13 13:03:53 -05:00

1035 lines
26 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mysql
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"os"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/VividCortex/mysqlerr"
"github.com/WatchBeam/clock"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/kit/log"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDatastoreReplica(t *testing.T) {
// a bit unfortunate to create temp databases just for this - could be mixed
// with other tests when/if we move to subtests to minimize the number of
// databases created for tests (see #1805).
ctx := context.Background()
t.Run("noreplica", func(t *testing.T) {
ds := CreateMySQLDSWithOptions(t, nil)
defer ds.Close()
require.Equal(t, ds.reader(ctx), ds.writer(ctx))
})
t.Run("replica", func(t *testing.T) {
opts := &DatastoreTestOptions{Replica: true}
ds := CreateMySQLDSWithOptions(t, opts)
defer ds.Close()
require.NotEqual(t, ds.reader(ctx), ds.writer(ctx))
// create a new host
host, err := ds.NewHost(ctx, &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String("1"),
UUID: "1",
Hostname: "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
})
require.NoError(t, err)
require.NotNil(t, host)
// trying to read it fails, not replicated yet
_, err = ds.Host(ctx, host.ID)
require.Error(t, err)
require.True(t, errors.Is(err, sql.ErrNoRows), err)
// force read from primary works
ctx = ctxdb.RequirePrimary(ctx, true)
got, err := ds.Host(ctx, host.ID)
require.NoError(t, err)
require.Equal(t, host.ID, got.ID)
// but from replica still fails
ctx = ctxdb.RequirePrimary(ctx, false)
_, err = ds.Host(ctx, host.ID)
require.Error(t, err)
require.True(t, errors.Is(err, sql.ErrNoRows))
opts.RunReplication()
// now it can read it from replica
got, err = ds.Host(ctx, host.ID)
require.NoError(t, err)
require.Equal(t, host.ID, got.ID)
})
}
func TestSanitizeColumn(t *testing.T) {
t.Parallel()
testCases := []struct {
input string
output string
}{
{"", ""},
{"foobar-column", "`foobar-column`"},
{"foobar_column", "`foobar_column`"},
{"foobar;column", "`foobarcolumn`"},
{"foobar#", "`foobar`"},
{"foobar*baz", "`foobarbaz`"},
{"....", ""},
{"h.id", "`h`.`id`"},
{"id;delete from hosts", "`iddeletefromhosts`"},
{"select * from foo", "`selectfromfoo`"},
}
for _, tt := range testCases {
t.Run(tt.input, func(t *testing.T) {
require.Equal(t, tt.output, sanitizeColumn(tt.input))
})
}
}
func TestSearchLike(t *testing.T) {
t.Parallel()
testCases := []struct {
inSQL string
inParams []interface{}
match string
columns []string
outSQL string
outParams []interface{}
}{
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{},
match: "foobar",
columns: []string{"hostname"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ?)",
outParams: []interface{}{"%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{3},
match: "foobar",
columns: []string{},
outSQL: "SELECT * FROM HOSTS WHERE TRUE",
outParams: []interface{}{3},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{1},
match: "foobar",
columns: []string{"hostname"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ?)",
outParams: []interface{}{1, "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{1},
match: "foobar",
columns: []string{"hostname", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%foobar%", "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE TRUE",
inParams: []interface{}{1},
match: "foobar",
columns: []string{"hostname", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE TRUE AND (hostname LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%foobar%", "%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
inParams: []interface{}{1},
match: "forty_%",
columns: []string{"ipv4", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%forty\\_\\%%", "%forty\\_\\%%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
inParams: []interface{}{1},
match: "forty_%",
columns: []string{"ipv4", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%forty\\_\\%%", "%forty\\_\\%%"},
},
{
inSQL: "SELECT * FROM HOSTS WHERE 1=1",
inParams: []interface{}{1},
match: "a@b.c",
columns: []string{"ipv4", "uuid"},
outSQL: "SELECT * FROM HOSTS WHERE 1=1 AND (ipv4 LIKE ? OR uuid LIKE ?)",
outParams: []interface{}{1, "%a@b.c%", "%a@b.c%"},
},
}
for _, tt := range testCases {
t.Run("", func(t *testing.T) {
sql, params := searchLike(tt.inSQL, tt.inParams, tt.match, tt.columns...)
assert.Equal(t, tt.outSQL, sql)
assert.Equal(t, tt.outParams, params)
})
}
}
func TestHostSearchLike(t *testing.T) {
t.Parallel()
testCases := []struct {
inSQL string
inParams []interface{}
match string
columns []string
outSQL string
outParams []interface{}
}{
{
inSQL: "SELECT * FROM HOSTS h WHERE TRUE",
inParams: []interface{}{},
match: "foobar",
columns: []string{"hostname"},
outSQL: "SELECT * FROM HOSTS h WHERE TRUE AND (hostname LIKE ?)",
outParams: []interface{}{"%foobar%"},
},
{
inSQL: "SELECT * FROM HOSTS h WHERE 1=1",
inParams: []interface{}{1},
match: "a@b.c",
columns: []string{"ipv4"},
outSQL: "SELECT * FROM HOSTS h WHERE 1=1 AND (ipv4 LIKE ? OR ( EXISTS (SELECT 1 FROM host_emails he WHERE he.host_id = h.id AND he.email LIKE ?)))",
outParams: []interface{}{1, "%a@b.c%", "%a@b.c%"},
},
}
for _, tt := range testCases {
t.Run("", func(t *testing.T) {
sql, params, _ := hostSearchLike(tt.inSQL, tt.inParams, tt.match, tt.columns...)
assert.Equal(t, tt.outSQL, sql)
assert.Equal(t, tt.outParams, params)
})
}
}
func mockDatastore(t *testing.T) (sqlmock.Sqlmock, *Datastore) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
dbmock := sqlx.NewDb(db, "sqlmock")
ds := &Datastore{
primary: dbmock,
replica: dbmock,
logger: log.NewNopLogger(),
}
return mock, ds
}
func TestWithRetryTxxSuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
require.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxRollbackSuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("fail"))
mock.ExpectRollback()
require.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxRollbackError(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("fail"))
mock.ExpectRollback().WillReturnError(errors.New("rollback failed"))
require.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxRetrySuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
// Return a retryable error
mock.ExpectExec("SELECT 1").WillReturnError(&mysql.MySQLError{Number: mysqlerr.ER_LOCK_DEADLOCK})
mock.ExpectRollback()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
assert.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxCommitRetrySuccess(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
// Return a retryable error
mock.ExpectCommit().WillReturnError(&mysql.MySQLError{Number: mysqlerr.ER_LOCK_DEADLOCK})
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
assert.NoError(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxxCommitError(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnResult(sqlmock.NewResult(1, 1))
// Return a retryable error
mock.ExpectCommit().WillReturnError(errors.New("fail"))
assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestAppendListOptionsToSQL(t *testing.T) {
sql := "SELECT * FROM my_table"
opts := fleet.ListOptions{
OrderKey: "***name***",
}
actual, _ := appendListOptionsToSQL(sql, &opts)
expected := "SELECT * FROM my_table ORDER BY `name` ASC LIMIT 1000000"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
sql = "SELECT * FROM my_table"
opts.OrderDirection = fleet.OrderDescending
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table ORDER BY `name` DESC LIMIT 1000000"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
opts = fleet.ListOptions{
PerPage: 10,
}
sql = "SELECT * FROM my_table"
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table LIMIT 10"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
sql = "SELECT * FROM my_table"
opts.Page = 2
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table LIMIT 10 OFFSET 20"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
opts = fleet.ListOptions{}
sql = "SELECT * FROM my_table"
actual, _ = appendListOptionsToSQL(sql, &opts)
expected = "SELECT * FROM my_table LIMIT 1000000"
if actual != expected {
t.Error("Expected", expected, "Actual", actual)
}
}
func TestWhereFilterHostsByTeams(t *testing.T) {
t.Parallel()
testCases := []struct {
filter fleet.TeamFilter
expected string
}{
// No teams or global role
{
filter: fleet.TeamFilter{
User: &fleet.User{},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{Teams: []fleet.UserTeam{}},
},
expected: "FALSE",
},
// Global role
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
},
expected: "TRUE",
},
// Team roles
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
},
},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
},
},
IncludeObserver: true,
},
expected: "hosts.team_id IN (1)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 2}},
},
},
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
},
expected: "hosts.team_id IN (2)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
IncludeObserver: true,
},
expected: "hosts.team_id IN (1,2)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
// Invalid role should be ignored
{Role: "bad", Team: fleet.Team{ID: 37}},
},
},
},
expected: "hosts.team_id IN (2)",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
{Role: fleet.RoleAdmin, Team: fleet.Team{ID: 3}},
// Invalid role should be ignored
},
},
},
expected: "hosts.team_id IN (2,3)",
},
{
filter: fleet.TeamFilter{
TeamID: ptr.Uint(1),
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
TeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: false,
TeamID: ptr.Uint(1),
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
IncludeObserver: false,
TeamID: ptr.Uint(1),
},
expected: "hosts.team_id = 1",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
TeamID: ptr.Uint(3),
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{
Teams: []fleet.UserTeam{
{Role: fleet.RoleObserver, Team: fleet.Team{ID: 1}},
{Role: fleet.RoleMaintainer, Team: fleet.Team{ID: 2}},
},
},
TeamID: ptr.Uint(2),
},
expected: "hosts.team_id = 2",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
t.Parallel()
ds := &Datastore{logger: log.NewNopLogger()}
sql := ds.whereFilterHostsByTeams(tt.filter, "hosts")
assert.Equal(t, tt.expected, sql)
})
}
}
func TestWhereOmitIDs(t *testing.T) {
t.Parallel()
testCases := []struct {
omits []uint
expected string
}{
{
omits: nil,
expected: "TRUE",
},
{
omits: []uint{},
expected: "TRUE",
},
{
omits: []uint{1, 3, 4},
expected: "id NOT IN (1,3,4)",
},
{
omits: []uint{42},
expected: "id NOT IN (42)",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
t.Parallel()
ds := &Datastore{logger: log.NewNopLogger()}
sql := ds.whereOmitIDs("id", tt.omits)
assert.Equal(t, tt.expected, sql)
})
}
}
func TestWithRetryTxWithRollback(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithRetryTxWillRollbackWhenPanic(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
defer func() { recover() }() //nolint:errcheck
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withRetryTxx(context.Background(), func(tx sqlx.ExtContext) error {
panic("ROLLBACK")
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithTxWithRollback(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withTx(context.Background(), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "SELECT 1")
return err
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestWithTxWillRollbackWhenPanic(t *testing.T) {
mock, ds := mockDatastore(t)
defer ds.Close()
defer func() { recover() }() //nolint:errcheck
mock.ExpectBegin()
mock.ExpectExec("SELECT 1").WillReturnError(errors.New("let's rollback!"))
mock.ExpectRollback()
assert.Error(t, ds.withTx(context.Background(), func(tx sqlx.ExtContext) error {
panic("ROLLBACK")
}))
require.NoError(t, mock.ExpectationsWereMet())
}
func TestNewReadsPasswordFromDisk(t *testing.T) {
passwordFile, err := os.CreateTemp(t.TempDir(), "*.passwordtest")
require.NoError(t, err)
_, err = passwordFile.WriteString(testPassword)
require.NoError(t, err)
passwordPath := passwordFile.Name()
require.NoError(t, passwordFile.Close())
dbName := t.Name()
// Create a datastore client in order to run migrations as usual
mysqlConfig := config.MysqlConfig{
Username: testUsername,
Password: "",
PasswordPath: passwordPath,
Address: testAddress,
Database: dbName,
}
ds, err := newDSWithConfig(t, dbName, mysqlConfig)
require.NoError(t, err)
defer ds.Close()
require.NoError(t, ds.HealthCheck())
}
func newDSWithConfig(t *testing.T, dbName string, config config.MysqlConfig) (*Datastore, error) {
db, err := sql.Open(
"mysql",
fmt.Sprintf("%s:%s@tcp(%s)/?multiStatements=true", testUsername, testPassword, testAddress),
)
require.NoError(t, err)
_, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s; CREATE DATABASE %s;", dbName, dbName))
require.NoError(t, err)
ds, err := New(config, clock.NewMockClock(), Logger(log.NewNopLogger()), LimitAttempts(1))
return ds, err
}
func generateTestCert(t *testing.T) (string, string) {
privateKeyCA, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"aa"},
},
NotBefore: time.Now().Add(-1 * time.Duration(24) * time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
IsCA: true,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKeyCA.PublicKey, privateKeyCA)
require.NoError(t, err)
publicPem, err := os.CreateTemp(t.TempDir(), "*-ca.pem")
require.NoError(t, err)
require.NoError(t, pem.Encode(publicPem, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
require.NoError(t, publicPem.Close())
keyPem, err := os.CreateTemp(t.TempDir(), "*-key.pem")
require.NoError(t, err)
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKeyCA)
require.NoError(t, pem.Encode(keyPem, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}))
require.NoError(t, keyPem.Close())
return publicPem.Name(), keyPem.Name()
}
func TestNewUsesRegisterTLS(t *testing.T) {
dbName := t.Name()
ca, _ := generateTestCert(t)
cert, key := generateTestCert(t)
mysqlConfig := config.MysqlConfig{
Username: testUsername,
Password: testPassword,
Address: testAddress,
Database: dbName,
TLSCA: ca,
TLSCert: cert,
TLSKey: key,
}
// This fails because the certificate mysql is using is different than the one generated here
_, err := newDSWithConfig(t, dbName, mysqlConfig)
require.Error(t, err)
// TODO: we're using a Regexp because the message is different depending on the version of mysql,
// we should refactor and use different error types instead.
require.Regexp(t, "^(x509|tls)", err.Error())
}
func TestWhereFilterTeams(t *testing.T) {
t.Parallel()
testCases := []struct {
filter fleet.TeamFilter
expected string
}{
// No teams or global role
{
filter: fleet.TeamFilter{User: nil},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: false,
},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
IncludeObserver: true,
},
expected: "TRUE",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}}},
expected: "t.id IN (1)",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}}},
expected: "t.id IN (1)",
},
{
filter: fleet.TeamFilter{User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}}},
expected: "FALSE",
},
{
filter: fleet.TeamFilter{
User: &fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
IncludeObserver: true,
},
expected: "t.id IN (1)",
},
}
for _, tt := range testCases {
tt := tt
t.Run("", func(t *testing.T) {
t.Parallel()
ds := &Datastore{logger: log.NewNopLogger()}
sql := ds.whereFilterTeams(tt.filter, "t")
assert.Equal(t, tt.expected, sql)
})
}
}
func TestCompareVersions(t *testing.T) {
for _, tc := range []struct {
name string
v1 []int64
v2 []int64
knownUnknowns map[int64]struct{}
expMissing []int64
expUnknown []int64
expEqual bool
}{
{
name: "both-empty",
v1: nil,
v2: nil,
expEqual: true,
},
{
name: "equal",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3},
expEqual: true,
},
{
name: "equal-out-of-order",
v1: []int64{1, 2, 3},
v2: []int64{1, 3, 2},
expEqual: true,
},
{
name: "empty-with-unknown",
v1: nil,
v2: []int64{1},
expEqual: false,
expUnknown: []int64{1},
},
{
name: "empty-with-missing",
v1: []int64{1},
v2: nil,
expEqual: false,
expMissing: []int64{1},
},
{
name: "missing",
v1: []int64{1, 2, 3},
v2: []int64{1, 3},
expMissing: []int64{2},
expEqual: false,
},
{
name: "unknown",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3, 4},
expUnknown: []int64{4},
expEqual: false,
},
{
name: "known-unknown",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3, 4},
knownUnknowns: map[int64]struct{}{
4: {},
},
expEqual: true,
},
{
name: "unknowns",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 3, 4, 5},
expUnknown: []int64{5},
knownUnknowns: map[int64]struct{}{
4: {},
},
expEqual: false,
},
{
name: "missing-and-unknown",
v1: []int64{1, 2, 3},
v2: []int64{1, 2, 4},
expMissing: []int64{3},
expUnknown: []int64{4},
expEqual: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
missing, unknown, equal := compareVersions(tc.v1, tc.v2, tc.knownUnknowns)
require.Equal(t, tc.expMissing, missing)
require.Equal(t, tc.expUnknown, unknown)
require.Equal(t, tc.expEqual, equal)
})
}
}
func TestDebugs(t *testing.T) {
ds := CreateMySQLDS(t)
status, err := ds.InnoDBStatus(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, status)
processList, err := ds.ProcessList(context.Background())
require.NoError(t, err)
require.Greater(t, len(processList), 0)
}
func TestWantedModesEnabled(t *testing.T) {
ds := CreateMySQLDS(t)
var sqlMode string
err := ds.writer(context.Background()).GetContext(context.Background(), &sqlMode, `SELECT @@SQL_MODE`)
require.NoError(t, err)
require.Contains(t, sqlMode, "ANSI_QUOTES")
require.Contains(t, sqlMode, "ONLY_FULL_GROUP_BY")
}
func Test_buildWildcardMatchPhrase(t *testing.T) {
type args struct {
matchQuery string
}
tests := []struct {
name string
args args
want string
}{
{
name: "",
args: args{matchQuery: "test"},
want: "%test%",
},
{
name: "underscores are escaped",
args: args{matchQuery: "Host_1"},
want: "%Host\\_1%",
},
{
name: "percent are escaped",
args: args{matchQuery: "Host%1"},
want: "%Host\\%1%",
},
{
name: "percent & underscore are escaped",
args: args{matchQuery: "Host_%1"},
want: "%Host\\_\\%1%",
},
{
name: "underscores added for wildcard search are not escaped",
args: args{matchQuery: "Alices MacbookPro"},
want: "%Alice_s MacbookPro%",
},
{
name: "underscores added for wildcard search are not escaped, but underscores in matchQuery are",
args: args{matchQuery: "Alices Macbook_Pro"},
want: "%Alice_s Macbook\\_Pro%",
},
{
name: "multiple occurances of wildcard are not escaped",
args: args{matchQuery: "Alices Macbook_Pro"},
want: "%Alice__s Macbook\\_Pro%",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, buildWildcardMatchPhrase(tt.args.matchQuery), "buildWildcardMatchPhrase(%v)", tt.args.matchQuery)
})
}
}