Load all host values from the db to prevent overwriting columns with empty vals (#2394)

This commit is contained in:
Tomas Touceda 2021-10-06 09:27:53 -03:00 committed by GitHub
parent 87b3563db7
commit 1af3803cdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 108 deletions

View File

@ -0,0 +1 @@
* Load all host values from the database to prevent overwriting them with empty values in certain cases.

View File

@ -560,49 +560,7 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin
func (d *Datastore) AuthenticateHost(ctx context.Context, nodeKey string) (*fleet.Host, error) {
// Select everything besides `additional`
sqlStatement := `
SELECT
id,
osquery_host_id,
created_at,
updated_at,
detail_updated_at,
label_updated_at,
policy_updated_at,
node_key,
hostname,
uuid,
platform,
osquery_version,
os_version,
build,
platform_like,
code_name,
uptime,
memory,
cpu_type,
cpu_subtype,
cpu_brand,
cpu_physical_cores,
cpu_logical_cores,
hardware_vendor,
hardware_model,
hardware_version,
hardware_serial,
computer_name,
primary_ip_id,
seen_time,
distributed_interval,
logger_tls_period,
config_tls_refresh,
primary_ip,
primary_mac,
refetch_requested,
team_id
FROM hosts
WHERE node_key = ?
LIMIT 1
`
sqlStatement := `SELECT * FROM hosts WHERE node_key = ? LIMIT 1`
host := &fleet.Host{}
if err := sqlx.GetContext(ctx, d.reader, host, sqlStatement, nodeKey); err != nil {

View File

@ -81,6 +81,7 @@ func TestHosts(t *testing.T) {
{"ListByPolicy", testHostsListByPolicy},
{"SaveTonsOfUsers", testHostsSaveTonsOfUsers},
{"SavePackStatsConcurrent", testHostsSavePackStatsConcurrent},
{"AuthenticateHostLoadsDisk", testAuthenticateHostLoadsDisk},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -109,10 +110,10 @@ func testHostsSave(t *testing.T, ds *Datastore) {
host.Hostname = "bar.local"
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "bar.local", host.Hostname)
assert.Equal(t, "192.168.1.1", host.PrimaryIP)
assert.Equal(t, "30-65-EC-6F-C4-58", host.PrimaryMac)
@ -125,16 +126,16 @@ func testHostsSave(t *testing.T, ds *Datastore) {
require.NoError(t, saveHostAdditionalDB(context.Background(), ds.writer, host))
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.NotNil(t, host)
require.NotNil(t, host.Additional)
assert.Equal(t, additionalJSON, *host.Additional)
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.NotNil(t, host)
err = ds.DeleteHost(context.Background(), host.ID)
@ -526,7 +527,7 @@ func testHostsDelete(t *testing.T, ds *Datastore) {
UUID: "1",
Hostname: "foo.local",
})
require.Nil(t, err)
require.NoError(t, err)
require.NotNil(t, host)
err = ds.DeleteHost(context.Background(), host.ID)
@ -547,7 +548,7 @@ func testHostsListFilterAdditional(t *testing.T, ds *Datastore) {
UUID: "uuid",
Hostname: "foobar.local",
})
require.Nil(t, err)
require.NoError(t, err)
filter := fleet.TeamFilter{User: test.UserAdmin}
@ -707,14 +708,14 @@ func testHostsEnroll(t *testing.T, ds *Datastore) {
filter := fleet.TeamFilter{User: test.UserAdmin}
hosts, err := ds.ListHosts(context.Background(), filter, fleet.HostListOptions{})
require.Nil(t, err)
require.NoError(t, err)
for _, host := range hosts {
assert.Zero(t, host.LastEnrolledAt)
}
for _, tt := range enrollTests {
h, err := ds.EnrollHost(context.Background(), tt.uuid, tt.nodeKey, &team.ID, 0)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, tt.uuid, h.OsqueryHostID)
assert.Equal(t, tt.nodeKey, h.NodeKey)
@ -730,7 +731,7 @@ func testHostsEnroll(t *testing.T, ds *Datastore) {
hosts, err = ds.ListHosts(context.Background(), filter, fleet.HostListOptions{})
require.Nil(t, err)
require.NoError(t, err)
for _, host := range hosts {
assert.NotZero(t, host.LastEnrolledAt)
}
@ -740,7 +741,7 @@ func testHostsAuthenticate(t *testing.T, ds *Datastore) {
test.AddAllHostsLabel(t, ds)
for _, tt := range enrollTests {
h, err := ds.EnrollHost(context.Background(), tt.uuid, tt.nodeKey, nil, 0)
require.Nil(t, err)
require.NoError(t, err)
returned, err := ds.AuthenticateHost(context.Background(), h.NodeKey)
require.NoError(t, err)
@ -758,7 +759,7 @@ func testHostsAuthenticateCaseSensitive(t *testing.T, ds *Datastore) {
test.AddAllHostsLabel(t, ds)
for _, tt := range enrollTests {
h, err := ds.EnrollHost(context.Background(), tt.uuid, tt.nodeKey, nil, 0)
require.Nil(t, err)
require.NoError(t, err)
_, err = ds.AuthenticateHost(context.Background(), strings.ToUpper(h.NodeKey))
require.Error(t, err, "node key authentication should be case sensitive")
@ -776,7 +777,7 @@ func testHostsSearch(t *testing.T, ds *Datastore) {
UUID: "1",
Hostname: "foo.local",
})
require.Nil(t, err)
require.NoError(t, err)
h2, err := ds.NewHost(context.Background(), &fleet.Host{
OsqueryHostID: "5679",
@ -788,7 +789,7 @@ func testHostsSearch(t *testing.T, ds *Datastore) {
UUID: "2",
Hostname: "bar.local",
})
require.Nil(t, err)
require.NoError(t, err)
h3, err := ds.NewHost(context.Background(), &fleet.Host{
OsqueryHostID: "99999",
@ -800,7 +801,7 @@ func testHostsSearch(t *testing.T, ds *Datastore) {
UUID: "abc-def-ghi",
Hostname: "foo-bar.local",
})
require.Nil(t, err)
require.NoError(t, err)
user := &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}
filter := fleet.TeamFilter{User: user}
@ -808,24 +809,24 @@ func testHostsSearch(t *testing.T, ds *Datastore) {
// We once threw errors when the search query was empty. Verify that we
// don't error.
_, err = ds.SearchHosts(context.Background(), filter, "")
require.Nil(t, err)
require.NoError(t, err)
hosts, err := ds.SearchHosts(context.Background(), filter, "foo")
assert.Nil(t, err)
assert.Len(t, hosts, 2)
host, err := ds.SearchHosts(context.Background(), filter, "foo", h3.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host, 1)
assert.Equal(t, "foo.local", host[0].Hostname)
host, err = ds.SearchHosts(context.Background(), filter, "foo", h3.ID, h2.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host, 1)
assert.Equal(t, "foo.local", host[0].Hostname)
host, err = ds.SearchHosts(context.Background(), filter, "abc")
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host, 1)
assert.Equal(t, "abc-def-ghi", host[0].UUID)
@ -836,24 +837,24 @@ func testHostsSearch(t *testing.T, ds *Datastore) {
// check to make sure search on ip address works
h2.PrimaryIP = "99.100.101.103"
err = ds.SaveHost(context.Background(), h2)
require.Nil(t, err)
require.NoError(t, err)
hits, err := ds.SearchHosts(context.Background(), filter, "99.100.101")
require.Nil(t, err)
require.NoError(t, err)
require.Equal(t, 1, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "99.100.111")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, 0, len(hits))
h3.PrimaryIP = "99.100.101.104"
err = ds.SaveHost(context.Background(), h3)
require.Nil(t, err)
require.NoError(t, err)
hits, err = ds.SearchHosts(context.Background(), filter, "99.100.101")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, 2, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "99.100.101", h3.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, 1, len(hits))
}
@ -871,11 +872,11 @@ func testHostsSearchLimit(t *testing.T, ds *Datastore) {
UUID: fmt.Sprintf("%d", i),
Hostname: fmt.Sprintf("foo.%d.local", i),
})
require.Nil(t, err)
require.NoError(t, err)
}
hosts, err := ds.SearchHosts(context.Background(), filter, "foo")
require.Nil(t, err)
require.NoError(t, err)
assert.Len(t, hosts, 10)
}
@ -900,7 +901,7 @@ func testHostsGenerateStatusStatistics(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: mockClock.Now().Add(-30 * time.Second),
SeenTime: mockClock.Now().Add(-30 * time.Second),
})
require.Nil(t, err)
require.NoError(t, err)
h.DistributedInterval = 15
h.ConfigTLSRefresh = 30
require.Nil(t, ds.SaveHost(context.Background(), h))
@ -915,7 +916,7 @@ func testHostsGenerateStatusStatistics(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: mockClock.Now().Add(-1 * time.Minute),
SeenTime: mockClock.Now().Add(-1 * time.Minute),
})
require.Nil(t, err)
require.NoError(t, err)
h.DistributedInterval = 60
h.ConfigTLSRefresh = 3600
require.Nil(t, ds.SaveHost(context.Background(), h))
@ -930,7 +931,7 @@ func testHostsGenerateStatusStatistics(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: mockClock.Now().Add(-1 * time.Hour),
SeenTime: mockClock.Now().Add(-1 * time.Hour),
})
require.Nil(t, err)
require.NoError(t, err)
h.DistributedInterval = 300
h.ConfigTLSRefresh = 300
require.Nil(t, ds.SaveHost(context.Background(), h))
@ -945,7 +946,7 @@ func testHostsGenerateStatusStatistics(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: mockClock.Now().Add(-35 * (24 * time.Hour)),
SeenTime: mockClock.Now().Add(-35 * (24 * time.Hour)),
})
require.Nil(t, err)
require.NoError(t, err)
online, offline, mia, new, err = ds.GenerateHostStatusStatistics(context.Background(), filter, mockClock.Now())
assert.Nil(t, err)
@ -1015,7 +1016,7 @@ func testHostsMarkSeenMany(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: aDayAgo,
SeenTime: aDayAgo,
})
require.Nil(t, err)
require.NoError(t, err)
h2, err := ds.NewHost(context.Background(), &fleet.Host{
ID: 2,
@ -1027,7 +1028,7 @@ func testHostsMarkSeenMany(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: aDayAgo,
SeenTime: aDayAgo,
})
require.Nil(t, err)
require.NoError(t, err)
err = ds.MarkHostsSeen(context.Background(), []uint{h1.ID}, anHourAgo)
assert.Nil(t, err)
@ -1073,7 +1074,7 @@ func testHostsCleanupIncoming(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: mockClock.Now(),
SeenTime: mockClock.Now(),
})
require.Nil(t, err)
require.NoError(t, err)
h2, err := ds.NewHost(context.Background(), &fleet.Host{
ID: 2,
@ -1087,7 +1088,7 @@ func testHostsCleanupIncoming(t *testing.T, ds *Datastore) {
PolicyUpdatedAt: mockClock.Now(),
SeenTime: mockClock.Now(),
})
require.Nil(t, err)
require.NoError(t, err)
err = ds.CleanupIncomingHosts(context.Background(), mockClock.Now().UTC())
assert.Nil(t, err)
@ -1120,16 +1121,41 @@ func testHostsIDsByName(t *testing.T, ds *Datastore) {
UUID: fmt.Sprintf("%d", i),
Hostname: fmt.Sprintf("foo.%d.local", i),
})
require.Nil(t, err)
require.NoError(t, err)
}
filter := fleet.TeamFilter{User: test.UserAdmin}
hosts, err := ds.HostIDsByName(context.Background(), filter, []string{"foo.2.local", "foo.1.local", "foo.5.local"})
require.Nil(t, err)
require.NoError(t, err)
sort.Slice(hosts, func(i, j int) bool { return hosts[i] < hosts[j] })
assert.Equal(t, hosts, []uint{2, 3, 6})
}
func testAuthenticateHostLoadsDisk(t *testing.T, ds *Datastore) {
h, err := ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
OsqueryHostID: "foobar",
NodeKey: "nodekey",
UUID: "uuid",
Hostname: "foobar.local",
})
require.NoError(t, err)
h.GigsDiskSpaceAvailable = 1.24
h.PercentDiskSpaceAvailable = 42.0
require.NoError(t, ds.SaveHost(context.Background(), h))
h, err = ds.Host(context.Background(), h.ID)
require.NoError(t, err)
h, err = ds.AuthenticateHost(context.Background(), "nodekey")
require.NoError(t, err)
assert.NotZero(t, h.GigsDiskSpaceAvailable)
assert.NotZero(t, h.PercentDiskSpaceAvailable)
}
func testHostsAdditional(t *testing.T, ds *Datastore) {
_, err := ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
@ -1141,16 +1167,16 @@ func testHostsAdditional(t *testing.T, ds *Datastore) {
UUID: "uuid",
Hostname: "foobar.local",
})
require.Nil(t, err)
require.NoError(t, err)
h, err := ds.AuthenticateHost(context.Background(), "nodekey")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "foobar.local", h.Hostname)
assert.Nil(t, h.Additional)
// Additional not yet set
h, err = ds.Host(context.Background(), h.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Nil(t, h.Additional)
// Add additional
@ -1160,45 +1186,45 @@ func testHostsAdditional(t *testing.T, ds *Datastore) {
// Additional should not be loaded for authenticatehost
h, err = ds.AuthenticateHost(context.Background(), "nodekey")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "foobar.local", h.Hostname)
assert.Nil(t, h.Additional)
h, err = ds.Host(context.Background(), h.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, &additional, h.Additional)
// Update besides additional. Additional should be unchanged.
h, err = ds.AuthenticateHost(context.Background(), "nodekey")
require.Nil(t, err)
require.NoError(t, err)
h.Hostname = "baz.local"
err = ds.SaveHost(context.Background(), h)
require.Nil(t, err)
require.NoError(t, err)
h, err = ds.AuthenticateHost(context.Background(), "nodekey")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "baz.local", h.Hostname)
assert.Nil(t, h.Additional)
h, err = ds.Host(context.Background(), h.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, &additional, h.Additional)
// Update additional
additional = json.RawMessage(`{"other": "additional"}`)
h, err = ds.AuthenticateHost(context.Background(), "nodekey")
require.Nil(t, err)
require.NoError(t, err)
h.Additional = &additional
err = saveHostAdditionalDB(context.Background(), ds.writer, h)
require.Nil(t, err)
require.NoError(t, err)
h, err = ds.AuthenticateHost(context.Background(), "nodekey")
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, "baz.local", h.Hostname)
assert.Nil(t, h.Additional)
h, err = ds.Host(context.Background(), h.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, &additional, h.Additional)
}
@ -1214,7 +1240,7 @@ func testHostsByIdentifier(t *testing.T, ds *Datastore) {
UUID: fmt.Sprintf("uuid_%d", i),
Hostname: fmt.Sprintf("hostname_%d", i),
})
require.Nil(t, err)
require.NoError(t, err)
}
var (
@ -1304,10 +1330,10 @@ func testHostsSaveUsers(t *testing.T, ds *Datastore) {
require.NotNil(t, host)
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Len(t, host.Users, 0)
u1 := fleet.HostUser{
@ -1326,10 +1352,10 @@ func testHostsSaveUsers(t *testing.T, ds *Datastore) {
host.Modified = true
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host.Users, 2)
test.ElementsMatchSkipID(t, host.Users, []fleet.HostUser{u1, u2})
@ -1338,10 +1364,10 @@ func testHostsSaveUsers(t *testing.T, ds *Datastore) {
host.Modified = true
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host.Users, 1)
assert.Equal(t, host.Users[0].Uid, u2.Uid)
@ -1350,10 +1376,10 @@ func testHostsSaveUsers(t *testing.T, ds *Datastore) {
host.Modified = true
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host.Users, 2)
test.ElementsMatchSkipID(t, host.Users, []fleet.HostUser{u1, u2})
}
@ -1374,10 +1400,10 @@ func testHostsSaveUsersWithoutUid(t *testing.T, ds *Datastore) {
require.NotNil(t, host)
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
assert.Len(t, host.Users, 0)
u1 := fleet.HostUser{
@ -1394,10 +1420,10 @@ func testHostsSaveUsersWithoutUid(t *testing.T, ds *Datastore) {
host.Modified = true
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host.Users, 2)
test.ElementsMatchSkipID(t, host.Users, []fleet.HostUser{u1, u2})
@ -1406,10 +1432,10 @@ func testHostsSaveUsersWithoutUid(t *testing.T, ds *Datastore) {
host.Modified = true
err = ds.SaveHost(context.Background(), host)
require.Nil(t, err)
require.NoError(t, err)
host, err = ds.Host(context.Background(), host.ID)
require.Nil(t, err)
require.NoError(t, err)
require.Len(t, host.Users, 1)
assert.Equal(t, host.Users[0].Uid, u2.Uid)
}