Fix issue where config interval can be number or string (#1432)

* Fix issue where config interval can be number or string

* Implemented @groob code review suggestions

* Added type assertions with graceful failure if something slips through validation

* Implemented code review changes per @zwass
This commit is contained in:
John Murphy 2017-03-21 12:02:13 -05:00 committed by GitHub
parent b59cd2b48b
commit 404afe3e22
6 changed files with 274 additions and 64 deletions

View File

@ -1,9 +1,11 @@
package kolide
import (
"bytes"
"context"
"errors"
"fmt"
"strconv"
"strings"
)
@ -99,12 +101,12 @@ const (
// schedule section of an osquery configuration.
type QueryDetails struct {
Query string `json:"query"`
Interval uint `json:"interval"`
Interval OsQueryConfigInt `json:"interval"`
// Optional fields
Removed *bool `json:"removed"`
Platform *string `json:"platform"`
Version *string `json:"version"`
Shard *uint `json:"shard"`
Shard *OsQueryConfigInt `json:"shard"`
Snapshot *bool `json:"snapshot"`
}
@ -112,7 +114,7 @@ type QueryDetails struct {
// file.
type PackDetails struct {
Queries QueryNameToQueryDetailsMap `json:"queries"`
Shard *uint `json:"shard"`
Shard *OsQueryConfigInt `json:"shard"`
Version *string `json:"version"`
Platform string `json:"platform"`
Discovery []string `json:"discovery"`
@ -215,3 +217,18 @@ func (ic *ImportConfig) CollectPacks() (PackNameToPackDetails, error) {
}
return result, nil
}
// OsQueryConfigInt is provided becase integers in the osquery config file may
// be represented as strings in the json. If we know a particular field is
// supposed to be an Integer, we convert from string to int if we can.
type OsQueryConfigInt uint
func (c *OsQueryConfigInt) UnmarshalJSON(b []byte) error {
stripped := bytes.Trim(b, `"`)
v, err := strconv.ParseUint(string(stripped), 10, 64)
if err != nil {
return err
}
*c = OsQueryConfigInt(v)
return nil
}

View File

@ -1,14 +1,66 @@
package kolide
import (
"bytes"
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIntervalUnmarshal(t *testing.T) {
scenarios := []struct {
name string
testVal interface{}
errExpected bool
expectedResult OsQueryConfigInt
}{
{"string to uint", "100", false, 100},
{"float to uint", float64(123), false, 123},
{"nil to zero value int", nil, false, 0},
{"invalid string", "hi there", true, 0},
}
for _, scenario := range scenarios {
t.Run(fmt.Sprintf(": %s", scenario.name), func(tt *testing.T) {
v, e := unmarshalInteger(scenario.testVal)
if scenario.errExpected {
assert.NotNil(t, e)
} else {
require.Nil(t, e)
assert.Equal(t, scenario.expectedResult, v)
}
})
}
}
type importIntTest struct {
Val OsQueryConfigInt `json:"val"`
}
func TestConfigImportInt(t *testing.T) {
buff := bytes.NewBufferString(`{"val":"23"}`)
var ts importIntTest
err := json.NewDecoder(buff).Decode(&ts)
assert.Nil(t, err)
assert.Equal(t, 23, int(ts.Val))
buff = bytes.NewBufferString(`{"val":456}`)
err = json.NewDecoder(buff).Decode(&ts)
assert.Nil(t, err)
assert.Equal(t, 456, int(ts.Val))
buff = bytes.NewBufferString(`{"val":"hi 456"}`)
err = json.NewDecoder(buff).Decode(&ts)
assert.NotNil(t, err)
}
func TestPackNameMapUnmarshal(t *testing.T) {
s2p := func(s string) *string { return &s }
u2p := func(ui uint) *OsQueryConfigInt { ci := OsQueryConfigInt(ui); return &ci }
pnm := PackNameMap{
"path": "/this/is/a/path",
"details": PackDetails{
@ -17,8 +69,8 @@ func TestPackNameMapUnmarshal(t *testing.T) {
Query: "select from foo",
Interval: 100,
Removed: new(bool),
Platform: strptr("linux"),
Shard: new(uint),
Platform: s2p("linux"),
Shard: new(OsQueryConfigInt),
Snapshot: new(bool),
},
},
@ -41,13 +93,13 @@ func TestPackNameMapUnmarshal(t *testing.T) {
Query: "select from foo",
Interval: 100,
Removed: new(bool),
Platform: strptr("linux"),
Shard: new(uint),
Platform: s2p("linux"),
Shard: new(OsQueryConfigInt),
Snapshot: new(bool),
},
},
Shard: uintptr(float64(10)),
Version: strptr("1.0"),
Shard: u2p(10),
Version: s2p("1.0"),
Platform: "linux",
Discovery: []string{
"select from something",
@ -59,13 +111,13 @@ func TestPackNameMapUnmarshal(t *testing.T) {
Query: "select from bar",
Interval: 100,
Removed: new(bool),
Platform: strptr("linux"),
Shard: new(uint),
Platform: s2p("linux"),
Shard: new(OsQueryConfigInt),
Snapshot: new(bool),
},
},
Shard: uintptr(float64(10)),
Version: strptr("1.0"),
Shard: u2p(10),
Version: s2p("1.0"),
Platform: "linux",
},
}

View File

@ -3,8 +3,11 @@ package kolide
import (
"encoding/json"
"errors"
"strconv"
)
var wrongTypeError = errors.New("argument missing or unexpected type")
// UnmarshalJSON custom unmarshaling for PackNameMap will determine whether
// the pack section of an osquery config file refers to a file path, or
// pack details. Pack details are unmarshalled into into PackDetails structure
@ -20,7 +23,11 @@ func (pnm PackNameMap) UnmarshalJSON(b []byte) error {
case string:
pnm[key] = t
case map[string]interface{}:
pnm[key] = unmarshalPackDetails(t)
val, err := unmarshalPackDetails(t)
if err != nil {
return err
}
pnm[key] = val
default:
return errors.New("can't unmarshal json")
}
@ -28,75 +35,176 @@ func (pnm PackNameMap) UnmarshalJSON(b []byte) error {
return nil
}
func strptr(v interface{}) *string {
func strptr(v interface{}) (*string, error) {
if v == nil {
return nil
return nil, nil
}
s := new(string)
*s = v.(string)
return s
s, ok := v.(string)
if !ok {
return nil, wrongTypeError
}
return &s, nil
}
func boolptr(v interface{}) *bool {
func boolptr(v interface{}) (*bool, error) {
if v == nil {
return nil
return nil, nil
}
b := new(bool)
*b = v.(bool)
return b
b, ok := v.(bool)
if !ok {
return nil, wrongTypeError
}
return &b, nil
}
func uintptr(v interface{}) *uint {
// We expect a float64 here because of the way JSON represents numbers
func uintptr(v interface{}) (*OsQueryConfigInt, error) {
if v == nil {
return nil
return nil, nil
}
i := new(uint)
*i = uint(v.(float64))
return i
i, err := unmarshalInteger(v)
if err != nil {
return nil, err
}
return &i, nil
}
func unmarshalPackDetails(v map[string]interface{}) PackDetails {
return PackDetails{
Queries: unmarshalQueryDetails(v["queries"]),
Shard: uintptr(v["shard"]),
Version: strptr(v["version"]),
Platform: v["platform"].(string),
Discovery: unmarshalDiscovery(v["discovery"]),
// Use this when we expext a string value, in this case nil is an error
func toString(v interface{}) (string, error) {
if s, ok := v.(string); ok {
return s, nil
}
return "", wrongTypeError
}
func unmarshalDiscovery(val interface{}) []string {
func unmarshalPackDetails(v map[string]interface{}) (PackDetails, error) {
var result PackDetails
queries, err := unmarshalQueryDetails(v["queries"])
if err != nil {
return result, err
}
discovery, err := unmarshalDiscovery(v["discovery"])
if err != nil {
return result, err
}
platform, err := toString(v["platform"])
if err != nil {
return result, err
}
shard, err := uintptr(v["shard"])
if err != nil {
return result, err
}
version, err := strptr(v["version"])
if err != nil {
return result, err
}
result = PackDetails{
Queries: queries,
Shard: shard,
Version: version,
Platform: platform,
Discovery: discovery,
}
return result, nil
}
func unmarshalDiscovery(val interface{}) ([]string, error) {
var result []string
if val == nil {
return result
return result, nil
}
v, ok := val.([]interface{})
if !ok {
return result, wrongTypeError
}
v := val.([]interface{})
for _, val := range v {
result = append(result, val.(string))
query, err := toString(val)
if err != nil {
return result, err
}
return result
result = append(result, query)
}
return result, nil
}
func unmarshalQueryDetails(v interface{}) QueryNameToQueryDetailsMap {
func unmarshalQueryDetails(v interface{}) (QueryNameToQueryDetailsMap, error) {
var err error
result := make(QueryNameToQueryDetailsMap)
if v == nil {
return result
return result, nil
}
for qn, details := range v.(map[string]interface{}) {
result[qn] = unmarshalQueryDetail(details)
result[qn], err = unmarshalQueryDetail(details)
if err != nil {
return nil, err
}
return result
}
return result, nil
}
func unmarshalQueryDetail(val interface{}) QueryDetails {
v := val.(map[string]interface{})
return QueryDetails{
Query: v["query"].(string),
Interval: uint(v["interval"].(float64)),
Removed: boolptr(v["removed"]),
Platform: strptr(v["platform"]),
Version: strptr(v["version"]),
Shard: uintptr(v["shard"]),
Snapshot: boolptr(v["snapshot"]),
func unmarshalQueryDetail(val interface{}) (QueryDetails, error) {
var result QueryDetails
v, ok := val.(map[string]interface{})
if !ok {
return result, errors.New("argument was missing or the wrong type")
}
interval, err := unmarshalInteger(v["interval"])
if err != nil {
return result, err
}
query, err := toString(v["query"])
if err != nil {
return result, err
}
removed, err := boolptr(v["removed"])
if err != nil {
return result, err
}
platform, err := strptr(v["platform"])
if err != nil {
return result, err
}
version, err := strptr(v["version"])
if err != nil {
return result, err
}
shard, err := uintptr(v["shard"])
if err != nil {
return result, err
}
snapshot, err := boolptr(v["snapshot"])
if err != nil {
return result, nil
}
result = QueryDetails{
Query: query,
Interval: OsQueryConfigInt(interval),
Removed: removed,
Platform: platform,
Version: version,
Shard: shard,
Snapshot: snapshot,
}
return result, nil
}
// It is valid for the interval can be a string that is convertable to an int,
// or an float64. The float64 is how all numbers in JSON are represented, so
// we need to convert to uint
func unmarshalInteger(val interface{}) (OsQueryConfigInt, error) {
// if interval is nil return zero value
if val == nil {
return OsQueryConfigInt(0), nil
}
switch v := val.(type) {
case string:
i, err := strconv.ParseUint(v, 10, 64)
return OsQueryConfigInt(i), err
case float64:
return OsQueryConfigInt(v), nil
default:
return OsQueryConfigInt(0), wrongTypeError
}
}

View File

@ -60,6 +60,30 @@ func testImportConfigWithMissingGlob(t *testing.T, r *testResource) {
}
func testImportConfigWithIntAsString(t *testing.T, r *testResource) {
testJSON := `
{
"config": "{\"options\":{\"host_identifier\":\"hostname\",\"schedule_splay_percent\":10},\"schedule\":{\"macosx_kextstat\":{\"query\":\"SELECT * FROM kernel_extensions;\",\"interval\":\"10\"},\"foobar\":{\"query\":\"SELECT foo, bar, pid FROM foobar_table;\",\"interval\":600}},\"packs\":{\"external_pack\":\"/path/to/external_pack.conf\",\"internal_pack\":{\"discovery\":[\"select pid from processes where name = 'foobar';\",\"select count(*) from users where username like 'www%';\"],\"platform\":\"linux\",\"version\":\"1.5.2\",\"queries\":{\"active_directory\":{\"query\":\"select * from ad_config;\",\"interval\":\"1200\",\"description\":\"Check each user's active directory cached settings.\"}}}},\"decorators\":{\"load\":[\"SELECT version FROM osquery_info\",\"SELECT uuid AS host_uuid FROM system_info\"],\"always\":[\"SELECT user AS username FROM logged_in_users WHERE user <> '' ORDER BY time LIMIT 1;\"],\"interval\":{\"3600\":[\"SELECT total_seconds AS uptime FROM uptime;\"]}},\"yara\":{\"signatures\":{\"sig_group_1\":[\"/Users/wxs/sigs/foo.sig\",\"/Users/wxs/sigs/bar.sig\"],\"sig_group_2\":[\"/Users/wxs/sigs/baz.sig\"]},\"file_paths\":{\"system_binaries\":[\"sig_group_1\"],\"tmp\":[\"sig_group_1\",\"sig_group_2\"]}},\"file_paths\":{\"system_binaries\":[\"/usr/bin/%\",\"/usr/sbin/%\"],\"tmp\":[\"/Users/%/tmp/%%\",\"/tmp/%\"]}}",
"external_pack_configs": {
"external_pack": "{\"discovery\":[\"select pid from processes where name = 'baz';\"],\"platform\":\"linux\",\"version\":\"1.5.2\",\"queries\":{\"something\":{\"query\":\"select * from something;\",\"interval\":1200,\"description\":\"Check something.\"}}}"
}
}
`
buff := bytes.NewBufferString(testJSON)
req, err := http.NewRequest("POST", r.server.URL+"/api/v1/kolide/osquery/config/import", buff)
require.Nil(t, err)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.adminToken))
client := &http.Client{}
resp, err := client.Do(req)
require.Nil(t, err)
var impResponse importResponse
err = json.NewDecoder(resp.Body).Decode(&impResponse)
require.Nil(t, err)
assert.Equal(t, 2, impResponse.Response.ImportStatusBySection[kolide.YARASigSection].ImportCount)
assert.Equal(t, 4, impResponse.Response.ImportStatusBySection[kolide.DecoratorsSection].ImportCount)
}
func testImportConfig(t *testing.T, r *testResource) {
testJSON := `

View File

@ -110,6 +110,7 @@ var testFunctions = [...]func(*testing.T, *testResource){
testImportConfigMissingExternal,
testImportConfigWithMissingGlob,
testImportConfigWithGlob,
testImportConfigWithIntAsString,
testAdminUserSetAdmin,
testNonAdminUserSetAdmin,
testAdminUserSetEnabled,

View File

@ -197,12 +197,12 @@ func (svc service) importScheduledQueries(uid uint, cfg *kolide.ImportConfig, re
sq := &kolide.ScheduledQuery{
PackID: pack.ID,
QueryID: query.ID,
Interval: queryDetails.Interval,
Interval: uint(queryDetails.Interval),
Snapshot: queryDetails.Snapshot,
Removed: queryDetails.Removed,
Platform: queryDetails.Platform,
Version: queryDetails.Version,
Shard: queryDetails.Shard,
Shard: configInt2Ptr(queryDetails.Shard),
}
_, err = svc.ds.NewScheduledQuery(sq)
if err != nil {
@ -317,12 +317,12 @@ func (svc service) createQueriesForPack(uid uint, pack *kolide.Pack, details *ko
scheduledQuery := &kolide.ScheduledQuery{
PackID: pack.ID,
QueryID: query.ID,
Interval: queryDetails.Interval,
Interval: uint(queryDetails.Interval),
Platform: queryDetails.Platform,
Snapshot: queryDetails.Snapshot,
Removed: queryDetails.Removed,
Version: queryDetails.Version,
Shard: queryDetails.Shard,
Shard: configInt2Ptr(queryDetails.Shard),
}
_, err = svc.ds.NewScheduledQuery(scheduledQuery)
if err != nil {
@ -419,3 +419,11 @@ func (svc service) importOptions(opts kolide.OptionNameToValueMap, resp *kolide.
}
return nil
}
func configInt2Ptr(ci *kolide.OsQueryConfigInt) *uint {
if ci == nil {
return nil
}
ui := uint(*ci)
return &ui
}