Prepare LoadHostByNodeKey query once (#4128)

* Prepare LoadHostByNodeKey query once

* Use a protected map for storing statements

* Add proposed test
This commit is contained in:
Lucas Manuel Rodriguez 2022-02-14 12:13:38 -03:00 committed by GitHub
parent 0ed0c4e27f
commit 9c25ea1641
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 127 additions and 6 deletions

1
go.mod
View File

@ -39,6 +39,7 @@ require (
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/gosuri/uilive v0.0.4
github.com/hashicorp/go-multierror v1.0.0 // indirect
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95
github.com/igm/sockjs-go/v3 v3.0.0
github.com/jinzhu/copier v0.3.2

View File

@ -709,12 +709,25 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri
return &host, nil
}
// GetContextTryStmt will attempt to run sqlx.GetContext on a cached statement if available, resorting to ds.reader.
func (ds *Datastore) GetContextTryStmt(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
var err error
//nolint the statements are closed in Datastore.Close.
if stmt := ds.loadOrPrepareStmt(ctx, query); stmt != nil {
err = stmt.GetContext(ctx, dest, args...)
} else {
err = sqlx.GetContext(ctx, ds.reader, dest, query, args...)
}
return err
}
// LoadHostByNodeKey loads the whole host identified by the node key.
// If the node key is invalid it returns a NotFoundError.
func (ds *Datastore) LoadHostByNodeKey(ctx context.Context, nodeKey string) (*fleet.Host, error) {
sqlStatement := `SELECT * FROM hosts WHERE node_key = ?`
query := `SELECT * FROM hosts WHERE node_key = ?`
var host fleet.Host
switch err := sqlx.GetContext(ctx, ds.reader, &host, sqlStatement, nodeKey); {
switch err := ds.GetContextTryStmt(ctx, &host, query, nodeKey); {
case err == nil:
return &host, nil
case errors.Is(err, sql.ErrNoRows):
@ -1264,6 +1277,7 @@ func (ds *Datastore) GetMDM(ctx context.Context, hostID uint) (bool, string, boo
}
return dest.Enrolled, dest.ServerURL, dest.InstalledFromDep, nil
}
func (ds *Datastore) AggregatedMunkiVersion(ctx context.Context, teamID *uint) ([]fleet.AggregatedMunkiVersion, time.Time, error) {
id := uint(0)

View File

@ -95,6 +95,7 @@ func TestHosts(t *testing.T) {
{"SaveTonsOfUsers", testHostsSaveTonsOfUsers},
{"SavePackStatsConcurrent", testHostsSavePackStatsConcurrent},
{"LoadHostByNodeKeyLoadsDisk", testLoadHostByNodeKeyLoadsDisk},
{"LoadHostByNodeKeyUsesStmt", testLoadHostByNodeKeyUsesStmt},
{"HostsListBySoftware", testHostsListBySoftware},
{"HostsListFailingPolicies", printReadsInTest(testHostsListFailingPolicies)},
{"HostsExpiration", testHostsExpiration},
@ -1421,6 +1422,58 @@ func testLoadHostByNodeKeyLoadsDisk(t *testing.T, ds *Datastore) {
assert.NotZero(t, h.PercentDiskSpaceAvailable)
}
func testLoadHostByNodeKeyUsesStmt(t *testing.T, ds *Datastore) {
_, err := ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
OsqueryHostID: "foobar",
NodeKey: "nodekey",
UUID: "uuid",
Hostname: "foobar.local",
})
require.NoError(t, err)
_, err = ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
OsqueryHostID: "foobar2",
NodeKey: "nodekey2",
UUID: "uuid2",
Hostname: "foobar2.local",
})
require.NoError(t, err)
err = ds.closeStmts()
require.NoError(t, err)
ds.stmtCacheMu.Lock()
require.Len(t, ds.stmtCache, 0)
ds.stmtCacheMu.Unlock()
h, err := ds.LoadHostByNodeKey(context.Background(), "nodekey")
require.NoError(t, err)
require.Equal(t, "foobar.local", h.Hostname)
ds.stmtCacheMu.Lock()
require.Len(t, ds.stmtCache, 1)
ds.stmtCacheMu.Unlock()
h, err = ds.LoadHostByNodeKey(context.Background(), "nodekey")
require.NoError(t, err)
require.Equal(t, "foobar.local", h.Hostname)
ds.stmtCacheMu.Lock()
require.Len(t, ds.stmtCache, 1)
ds.stmtCacheMu.Unlock()
h, err = ds.LoadHostByNodeKey(context.Background(), "nodekey2")
require.NoError(t, err)
require.Equal(t, "foobar2.local", h.Hostname)
}
func testHostsAdditional(t *testing.T, ds *Datastore) {
h, err := ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),

View File

@ -11,6 +11,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/VividCortex/mysqlerr"
@ -27,6 +28,7 @@ import (
"github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/go-sql-driver/mysql"
"github.com/hashicorp/go-multierror"
"github.com/jmoiron/sqlx"
"github.com/ngrok/sqlmw"
)
@ -42,6 +44,7 @@ var columnCharsRegexp = regexp.MustCompile(`[^\w-.]`)
// dbReader is an interface that defines the methods required for reads.
type dbReader interface {
sqlx.QueryerContext
sqlx.PreparerContext
Close() error
Rebind(string) string
@ -61,6 +64,36 @@ type Datastore struct {
readReplicaConfig *config.MysqlConfig
writeCh chan itemToWrite
// stmtCacheMu protects access to stmtCache.
stmtCacheMu sync.Mutex
// stmtCache holds statements for queries.
stmtCache map[string]*sqlx.Stmt
}
// loadOrPrepareStmt will load a statement from the statements cache.
// If not available, it will attempt to prepare (create) it.
//
// Returns nil if it failed to prepare a statement.
func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.Stmt {
ds.stmtCacheMu.Lock()
defer ds.stmtCacheMu.Unlock()
stmt, ok := ds.stmtCache[query]
if !ok {
var err error
stmt, err = sqlx.PreparexContext(ctx, ds.reader, query)
if err != nil {
level.Error(ds.logger).Log(
"msg", "failed to prepare statement",
"query", query,
"err", err,
)
return nil
}
ds.stmtCache[query] = stmt
}
return stmt
}
type txFn func(sqlx.ExtContext) error
@ -217,6 +250,7 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
config: config,
readReplicaConfig: options.replicaConfig,
writeCh: make(chan itemToWrite),
stmtCache: make(map[string]*sqlx.Stmt),
}
go ds.writeChanLoop()
@ -482,13 +516,32 @@ func (ds *Datastore) HealthCheck() error {
return nil
}
func (ds *Datastore) closeStmts() error {
ds.stmtCacheMu.Lock()
defer ds.stmtCacheMu.Unlock()
var err error
for query, stmt := range ds.stmtCache {
if errClose := stmt.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
delete(ds.stmtCache, query)
}
return err
}
// Close frees resources associated with underlying mysql connection
func (ds *Datastore) Close() error {
err := ds.writer.Close()
var err error
if errStmt := ds.closeStmts(); errStmt != nil {
err = multierror.Append(err, errStmt)
}
if errWriter := ds.writer.Close(); errWriter != nil {
err = multierror.Append(err, errWriter)
}
if ds.readReplicaConfig != nil {
errRead := ds.reader.Close()
if err == nil {
err = errRead
if errRead := ds.reader.Close(); errRead != nil {
err = multierror.Append(err, errRead)
}
}
return err