mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Prepare LoadHostByNodeKey
query once (#4128)
* Prepare LoadHostByNodeKey query once * Use a protected map for storing statements * Add proposed test
This commit is contained in:
parent
0ed0c4e27f
commit
9c25ea1641
1
go.mod
1
go.mod
@ -39,6 +39,7 @@ require (
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/gosuri/uilive v0.0.4
|
||||
github.com/hashicorp/go-multierror v1.0.0 // indirect
|
||||
github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95
|
||||
github.com/igm/sockjs-go/v3 v3.0.0
|
||||
github.com/jinzhu/copier v0.3.2
|
||||
|
@ -709,12 +709,25 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
// GetContextTryStmt will attempt to run sqlx.GetContext on a cached statement if available, resorting to ds.reader.
|
||||
func (ds *Datastore) GetContextTryStmt(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
var err error
|
||||
//nolint the statements are closed in Datastore.Close.
|
||||
if stmt := ds.loadOrPrepareStmt(ctx, query); stmt != nil {
|
||||
err = stmt.GetContext(ctx, dest, args...)
|
||||
} else {
|
||||
err = sqlx.GetContext(ctx, ds.reader, dest, query, args...)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadHostByNodeKey loads the whole host identified by the node key.
|
||||
// If the node key is invalid it returns a NotFoundError.
|
||||
func (ds *Datastore) LoadHostByNodeKey(ctx context.Context, nodeKey string) (*fleet.Host, error) {
|
||||
sqlStatement := `SELECT * FROM hosts WHERE node_key = ?`
|
||||
query := `SELECT * FROM hosts WHERE node_key = ?`
|
||||
|
||||
var host fleet.Host
|
||||
switch err := sqlx.GetContext(ctx, ds.reader, &host, sqlStatement, nodeKey); {
|
||||
switch err := ds.GetContextTryStmt(ctx, &host, query, nodeKey); {
|
||||
case err == nil:
|
||||
return &host, nil
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
@ -1264,6 +1277,7 @@ func (ds *Datastore) GetMDM(ctx context.Context, hostID uint) (bool, string, boo
|
||||
}
|
||||
return dest.Enrolled, dest.ServerURL, dest.InstalledFromDep, nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) AggregatedMunkiVersion(ctx context.Context, teamID *uint) ([]fleet.AggregatedMunkiVersion, time.Time, error) {
|
||||
id := uint(0)
|
||||
|
||||
|
@ -95,6 +95,7 @@ func TestHosts(t *testing.T) {
|
||||
{"SaveTonsOfUsers", testHostsSaveTonsOfUsers},
|
||||
{"SavePackStatsConcurrent", testHostsSavePackStatsConcurrent},
|
||||
{"LoadHostByNodeKeyLoadsDisk", testLoadHostByNodeKeyLoadsDisk},
|
||||
{"LoadHostByNodeKeyUsesStmt", testLoadHostByNodeKeyUsesStmt},
|
||||
{"HostsListBySoftware", testHostsListBySoftware},
|
||||
{"HostsListFailingPolicies", printReadsInTest(testHostsListFailingPolicies)},
|
||||
{"HostsExpiration", testHostsExpiration},
|
||||
@ -1421,6 +1422,58 @@ func testLoadHostByNodeKeyLoadsDisk(t *testing.T, ds *Datastore) {
|
||||
assert.NotZero(t, h.PercentDiskSpaceAvailable)
|
||||
}
|
||||
|
||||
func testLoadHostByNodeKeyUsesStmt(t *testing.T, ds *Datastore) {
|
||||
_, err := ds.NewHost(context.Background(), &fleet.Host{
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
PolicyUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
OsqueryHostID: "foobar",
|
||||
NodeKey: "nodekey",
|
||||
UUID: "uuid",
|
||||
Hostname: "foobar.local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ds.NewHost(context.Background(), &fleet.Host{
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
PolicyUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
OsqueryHostID: "foobar2",
|
||||
NodeKey: "nodekey2",
|
||||
UUID: "uuid2",
|
||||
Hostname: "foobar2.local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.closeStmts()
|
||||
require.NoError(t, err)
|
||||
|
||||
ds.stmtCacheMu.Lock()
|
||||
require.Len(t, ds.stmtCache, 0)
|
||||
ds.stmtCacheMu.Unlock()
|
||||
|
||||
h, err := ds.LoadHostByNodeKey(context.Background(), "nodekey")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foobar.local", h.Hostname)
|
||||
|
||||
ds.stmtCacheMu.Lock()
|
||||
require.Len(t, ds.stmtCache, 1)
|
||||
ds.stmtCacheMu.Unlock()
|
||||
|
||||
h, err = ds.LoadHostByNodeKey(context.Background(), "nodekey")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foobar.local", h.Hostname)
|
||||
|
||||
ds.stmtCacheMu.Lock()
|
||||
require.Len(t, ds.stmtCache, 1)
|
||||
ds.stmtCacheMu.Unlock()
|
||||
|
||||
h, err = ds.LoadHostByNodeKey(context.Background(), "nodekey2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foobar2.local", h.Hostname)
|
||||
}
|
||||
|
||||
func testHostsAdditional(t *testing.T, ds *Datastore) {
|
||||
h, err := ds.NewHost(context.Background(), &fleet.Host{
|
||||
DetailUpdatedAt: time.Now(),
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/VividCortex/mysqlerr"
|
||||
@ -27,6 +28,7 @@ import (
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/ngrok/sqlmw"
|
||||
)
|
||||
@ -42,6 +44,7 @@ var columnCharsRegexp = regexp.MustCompile(`[^\w-.]`)
|
||||
// dbReader is an interface that defines the methods required for reads.
|
||||
type dbReader interface {
|
||||
sqlx.QueryerContext
|
||||
sqlx.PreparerContext
|
||||
|
||||
Close() error
|
||||
Rebind(string) string
|
||||
@ -61,6 +64,36 @@ type Datastore struct {
|
||||
readReplicaConfig *config.MysqlConfig
|
||||
|
||||
writeCh chan itemToWrite
|
||||
|
||||
// stmtCacheMu protects access to stmtCache.
|
||||
stmtCacheMu sync.Mutex
|
||||
// stmtCache holds statements for queries.
|
||||
stmtCache map[string]*sqlx.Stmt
|
||||
}
|
||||
|
||||
// loadOrPrepareStmt will load a statement from the statements cache.
|
||||
// If not available, it will attempt to prepare (create) it.
|
||||
//
|
||||
// Returns nil if it failed to prepare a statement.
|
||||
func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.Stmt {
|
||||
ds.stmtCacheMu.Lock()
|
||||
defer ds.stmtCacheMu.Unlock()
|
||||
|
||||
stmt, ok := ds.stmtCache[query]
|
||||
if !ok {
|
||||
var err error
|
||||
stmt, err = sqlx.PreparexContext(ctx, ds.reader, query)
|
||||
if err != nil {
|
||||
level.Error(ds.logger).Log(
|
||||
"msg", "failed to prepare statement",
|
||||
"query", query,
|
||||
"err", err,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
ds.stmtCache[query] = stmt
|
||||
}
|
||||
return stmt
|
||||
}
|
||||
|
||||
type txFn func(sqlx.ExtContext) error
|
||||
@ -217,6 +250,7 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore
|
||||
config: config,
|
||||
readReplicaConfig: options.replicaConfig,
|
||||
writeCh: make(chan itemToWrite),
|
||||
stmtCache: make(map[string]*sqlx.Stmt),
|
||||
}
|
||||
|
||||
go ds.writeChanLoop()
|
||||
@ -482,13 +516,32 @@ func (ds *Datastore) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) closeStmts() error {
|
||||
ds.stmtCacheMu.Lock()
|
||||
defer ds.stmtCacheMu.Unlock()
|
||||
|
||||
var err error
|
||||
for query, stmt := range ds.stmtCache {
|
||||
if errClose := stmt.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
delete(ds.stmtCache, query)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close frees resources associated with underlying mysql connection
|
||||
func (ds *Datastore) Close() error {
|
||||
err := ds.writer.Close()
|
||||
var err error
|
||||
if errStmt := ds.closeStmts(); errStmt != nil {
|
||||
err = multierror.Append(err, errStmt)
|
||||
}
|
||||
if errWriter := ds.writer.Close(); errWriter != nil {
|
||||
err = multierror.Append(err, errWriter)
|
||||
}
|
||||
if ds.readReplicaConfig != nil {
|
||||
errRead := ds.reader.Close()
|
||||
if err == nil {
|
||||
err = errRead
|
||||
if errRead := ds.reader.Close(); errRead != nil {
|
||||
err = multierror.Append(err, errRead)
|
||||
}
|
||||
}
|
||||
return err
|
||||
|
Loading…
Reference in New Issue
Block a user