Fix issue where config interval can be number or string (#1432)

* Fix issue where config interval can be number or string

* Implemented @groob code review suggestions

* Added type assertions with graceful failure if something slips through validation

* Implemented code review changes per @zwass
This commit is contained in:
John Murphy 2017-03-21 12:02:13 -05:00 committed by GitHub
parent b59cd2b48b
commit 404afe3e22
6 changed files with 274 additions and 64 deletions

View File

@ -1,9 +1,11 @@
package kolide package kolide
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
) )
@ -99,12 +101,12 @@ const (
// schedule section of an osquery configuration. // schedule section of an osquery configuration.
type QueryDetails struct { type QueryDetails struct {
Query string `json:"query"` Query string `json:"query"`
Interval uint `json:"interval"` Interval OsQueryConfigInt `json:"interval"`
// Optional fields // Optional fields
Removed *bool `json:"removed"` Removed *bool `json:"removed"`
Platform *string `json:"platform"` Platform *string `json:"platform"`
Version *string `json:"version"` Version *string `json:"version"`
Shard *uint `json:"shard"` Shard *OsQueryConfigInt `json:"shard"`
Snapshot *bool `json:"snapshot"` Snapshot *bool `json:"snapshot"`
} }
@ -112,7 +114,7 @@ type QueryDetails struct {
// file. // file.
type PackDetails struct { type PackDetails struct {
Queries QueryNameToQueryDetailsMap `json:"queries"` Queries QueryNameToQueryDetailsMap `json:"queries"`
Shard *uint `json:"shard"` Shard *OsQueryConfigInt `json:"shard"`
Version *string `json:"version"` Version *string `json:"version"`
Platform string `json:"platform"` Platform string `json:"platform"`
Discovery []string `json:"discovery"` Discovery []string `json:"discovery"`
@ -215,3 +217,18 @@ func (ic *ImportConfig) CollectPacks() (PackNameToPackDetails, error) {
} }
return result, nil return result, nil
} }
// OsQueryConfigInt is provided becase integers in the osquery config file may
// be represented as strings in the json. If we know a particular field is
// supposed to be an Integer, we convert from string to int if we can.
type OsQueryConfigInt uint
func (c *OsQueryConfigInt) UnmarshalJSON(b []byte) error {
stripped := bytes.Trim(b, `"`)
v, err := strconv.ParseUint(string(stripped), 10, 64)
if err != nil {
return err
}
*c = OsQueryConfigInt(v)
return nil
}

View File

@ -1,14 +1,66 @@
package kolide package kolide
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestIntervalUnmarshal(t *testing.T) {
scenarios := []struct {
name string
testVal interface{}
errExpected bool
expectedResult OsQueryConfigInt
}{
{"string to uint", "100", false, 100},
{"float to uint", float64(123), false, 123},
{"nil to zero value int", nil, false, 0},
{"invalid string", "hi there", true, 0},
}
for _, scenario := range scenarios {
t.Run(fmt.Sprintf(": %s", scenario.name), func(tt *testing.T) {
v, e := unmarshalInteger(scenario.testVal)
if scenario.errExpected {
assert.NotNil(t, e)
} else {
require.Nil(t, e)
assert.Equal(t, scenario.expectedResult, v)
}
})
}
}
type importIntTest struct {
Val OsQueryConfigInt `json:"val"`
}
func TestConfigImportInt(t *testing.T) {
buff := bytes.NewBufferString(`{"val":"23"}`)
var ts importIntTest
err := json.NewDecoder(buff).Decode(&ts)
assert.Nil(t, err)
assert.Equal(t, 23, int(ts.Val))
buff = bytes.NewBufferString(`{"val":456}`)
err = json.NewDecoder(buff).Decode(&ts)
assert.Nil(t, err)
assert.Equal(t, 456, int(ts.Val))
buff = bytes.NewBufferString(`{"val":"hi 456"}`)
err = json.NewDecoder(buff).Decode(&ts)
assert.NotNil(t, err)
}
func TestPackNameMapUnmarshal(t *testing.T) { func TestPackNameMapUnmarshal(t *testing.T) {
s2p := func(s string) *string { return &s }
u2p := func(ui uint) *OsQueryConfigInt { ci := OsQueryConfigInt(ui); return &ci }
pnm := PackNameMap{ pnm := PackNameMap{
"path": "/this/is/a/path", "path": "/this/is/a/path",
"details": PackDetails{ "details": PackDetails{
@ -17,8 +69,8 @@ func TestPackNameMapUnmarshal(t *testing.T) {
Query: "select from foo", Query: "select from foo",
Interval: 100, Interval: 100,
Removed: new(bool), Removed: new(bool),
Platform: strptr("linux"), Platform: s2p("linux"),
Shard: new(uint), Shard: new(OsQueryConfigInt),
Snapshot: new(bool), Snapshot: new(bool),
}, },
}, },
@ -41,13 +93,13 @@ func TestPackNameMapUnmarshal(t *testing.T) {
Query: "select from foo", Query: "select from foo",
Interval: 100, Interval: 100,
Removed: new(bool), Removed: new(bool),
Platform: strptr("linux"), Platform: s2p("linux"),
Shard: new(uint), Shard: new(OsQueryConfigInt),
Snapshot: new(bool), Snapshot: new(bool),
}, },
}, },
Shard: uintptr(float64(10)), Shard: u2p(10),
Version: strptr("1.0"), Version: s2p("1.0"),
Platform: "linux", Platform: "linux",
Discovery: []string{ Discovery: []string{
"select from something", "select from something",
@ -59,13 +111,13 @@ func TestPackNameMapUnmarshal(t *testing.T) {
Query: "select from bar", Query: "select from bar",
Interval: 100, Interval: 100,
Removed: new(bool), Removed: new(bool),
Platform: strptr("linux"), Platform: s2p("linux"),
Shard: new(uint), Shard: new(OsQueryConfigInt),
Snapshot: new(bool), Snapshot: new(bool),
}, },
}, },
Shard: uintptr(float64(10)), Shard: u2p(10),
Version: strptr("1.0"), Version: s2p("1.0"),
Platform: "linux", Platform: "linux",
}, },
} }

View File

@ -3,8 +3,11 @@ package kolide
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"strconv"
) )
var wrongTypeError = errors.New("argument missing or unexpected type")
// UnmarshalJSON custom unmarshaling for PackNameMap will determine whether // UnmarshalJSON custom unmarshaling for PackNameMap will determine whether
// the pack section of an osquery config file refers to a file path, or // the pack section of an osquery config file refers to a file path, or
// pack details. Pack details are unmarshalled into into PackDetails structure // pack details. Pack details are unmarshalled into into PackDetails structure
@ -20,7 +23,11 @@ func (pnm PackNameMap) UnmarshalJSON(b []byte) error {
case string: case string:
pnm[key] = t pnm[key] = t
case map[string]interface{}: case map[string]interface{}:
pnm[key] = unmarshalPackDetails(t) val, err := unmarshalPackDetails(t)
if err != nil {
return err
}
pnm[key] = val
default: default:
return errors.New("can't unmarshal json") return errors.New("can't unmarshal json")
} }
@ -28,75 +35,176 @@ func (pnm PackNameMap) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func strptr(v interface{}) *string { func strptr(v interface{}) (*string, error) {
if v == nil { if v == nil {
return nil return nil, nil
} }
s := new(string) s, ok := v.(string)
*s = v.(string) if !ok {
return s return nil, wrongTypeError
}
return &s, nil
} }
func boolptr(v interface{}) *bool { func boolptr(v interface{}) (*bool, error) {
if v == nil { if v == nil {
return nil return nil, nil
} }
b := new(bool) b, ok := v.(bool)
*b = v.(bool) if !ok {
return b return nil, wrongTypeError
}
return &b, nil
} }
func uintptr(v interface{}) *uint { // We expect a float64 here because of the way JSON represents numbers
func uintptr(v interface{}) (*OsQueryConfigInt, error) {
if v == nil { if v == nil {
return nil return nil, nil
} }
i := new(uint) i, err := unmarshalInteger(v)
*i = uint(v.(float64)) if err != nil {
return i return nil, err
}
return &i, nil
} }
func unmarshalPackDetails(v map[string]interface{}) PackDetails { // Use this when we expext a string value, in this case nil is an error
return PackDetails{ func toString(v interface{}) (string, error) {
Queries: unmarshalQueryDetails(v["queries"]), if s, ok := v.(string); ok {
Shard: uintptr(v["shard"]), return s, nil
Version: strptr(v["version"]),
Platform: v["platform"].(string),
Discovery: unmarshalDiscovery(v["discovery"]),
} }
return "", wrongTypeError
} }
func unmarshalDiscovery(val interface{}) []string { func unmarshalPackDetails(v map[string]interface{}) (PackDetails, error) {
var result PackDetails
queries, err := unmarshalQueryDetails(v["queries"])
if err != nil {
return result, err
}
discovery, err := unmarshalDiscovery(v["discovery"])
if err != nil {
return result, err
}
platform, err := toString(v["platform"])
if err != nil {
return result, err
}
shard, err := uintptr(v["shard"])
if err != nil {
return result, err
}
version, err := strptr(v["version"])
if err != nil {
return result, err
}
result = PackDetails{
Queries: queries,
Shard: shard,
Version: version,
Platform: platform,
Discovery: discovery,
}
return result, nil
}
func unmarshalDiscovery(val interface{}) ([]string, error) {
var result []string var result []string
if val == nil { if val == nil {
return result return result, nil
}
v, ok := val.([]interface{})
if !ok {
return result, wrongTypeError
} }
v := val.([]interface{})
for _, val := range v { for _, val := range v {
result = append(result, val.(string)) query, err := toString(val)
if err != nil {
return result, err
} }
return result result = append(result, query)
}
return result, nil
} }
func unmarshalQueryDetails(v interface{}) QueryNameToQueryDetailsMap { func unmarshalQueryDetails(v interface{}) (QueryNameToQueryDetailsMap, error) {
var err error
result := make(QueryNameToQueryDetailsMap) result := make(QueryNameToQueryDetailsMap)
if v == nil { if v == nil {
return result return result, nil
} }
for qn, details := range v.(map[string]interface{}) { for qn, details := range v.(map[string]interface{}) {
result[qn] = unmarshalQueryDetail(details) result[qn], err = unmarshalQueryDetail(details)
if err != nil {
return nil, err
} }
return result }
return result, nil
} }
func unmarshalQueryDetail(val interface{}) QueryDetails { func unmarshalQueryDetail(val interface{}) (QueryDetails, error) {
v := val.(map[string]interface{}) var result QueryDetails
return QueryDetails{ v, ok := val.(map[string]interface{})
Query: v["query"].(string), if !ok {
Interval: uint(v["interval"].(float64)), return result, errors.New("argument was missing or the wrong type")
Removed: boolptr(v["removed"]), }
Platform: strptr(v["platform"]), interval, err := unmarshalInteger(v["interval"])
Version: strptr(v["version"]), if err != nil {
Shard: uintptr(v["shard"]), return result, err
Snapshot: boolptr(v["snapshot"]), }
query, err := toString(v["query"])
if err != nil {
return result, err
}
removed, err := boolptr(v["removed"])
if err != nil {
return result, err
}
platform, err := strptr(v["platform"])
if err != nil {
return result, err
}
version, err := strptr(v["version"])
if err != nil {
return result, err
}
shard, err := uintptr(v["shard"])
if err != nil {
return result, err
}
snapshot, err := boolptr(v["snapshot"])
if err != nil {
return result, nil
}
result = QueryDetails{
Query: query,
Interval: OsQueryConfigInt(interval),
Removed: removed,
Platform: platform,
Version: version,
Shard: shard,
Snapshot: snapshot,
}
return result, nil
}
// It is valid for the interval can be a string that is convertable to an int,
// or an float64. The float64 is how all numbers in JSON are represented, so
// we need to convert to uint
func unmarshalInteger(val interface{}) (OsQueryConfigInt, error) {
// if interval is nil return zero value
if val == nil {
return OsQueryConfigInt(0), nil
}
switch v := val.(type) {
case string:
i, err := strconv.ParseUint(v, 10, 64)
return OsQueryConfigInt(i), err
case float64:
return OsQueryConfigInt(v), nil
default:
return OsQueryConfigInt(0), wrongTypeError
} }
} }

View File

@ -60,6 +60,30 @@ func testImportConfigWithMissingGlob(t *testing.T, r *testResource) {
} }
func testImportConfigWithIntAsString(t *testing.T, r *testResource) {
testJSON := `
{
"config": "{\"options\":{\"host_identifier\":\"hostname\",\"schedule_splay_percent\":10},\"schedule\":{\"macosx_kextstat\":{\"query\":\"SELECT * FROM kernel_extensions;\",\"interval\":\"10\"},\"foobar\":{\"query\":\"SELECT foo, bar, pid FROM foobar_table;\",\"interval\":600}},\"packs\":{\"external_pack\":\"/path/to/external_pack.conf\",\"internal_pack\":{\"discovery\":[\"select pid from processes where name = 'foobar';\",\"select count(*) from users where username like 'www%';\"],\"platform\":\"linux\",\"version\":\"1.5.2\",\"queries\":{\"active_directory\":{\"query\":\"select * from ad_config;\",\"interval\":\"1200\",\"description\":\"Check each user's active directory cached settings.\"}}}},\"decorators\":{\"load\":[\"SELECT version FROM osquery_info\",\"SELECT uuid AS host_uuid FROM system_info\"],\"always\":[\"SELECT user AS username FROM logged_in_users WHERE user <> '' ORDER BY time LIMIT 1;\"],\"interval\":{\"3600\":[\"SELECT total_seconds AS uptime FROM uptime;\"]}},\"yara\":{\"signatures\":{\"sig_group_1\":[\"/Users/wxs/sigs/foo.sig\",\"/Users/wxs/sigs/bar.sig\"],\"sig_group_2\":[\"/Users/wxs/sigs/baz.sig\"]},\"file_paths\":{\"system_binaries\":[\"sig_group_1\"],\"tmp\":[\"sig_group_1\",\"sig_group_2\"]}},\"file_paths\":{\"system_binaries\":[\"/usr/bin/%\",\"/usr/sbin/%\"],\"tmp\":[\"/Users/%/tmp/%%\",\"/tmp/%\"]}}",
"external_pack_configs": {
"external_pack": "{\"discovery\":[\"select pid from processes where name = 'baz';\"],\"platform\":\"linux\",\"version\":\"1.5.2\",\"queries\":{\"something\":{\"query\":\"select * from something;\",\"interval\":1200,\"description\":\"Check something.\"}}}"
}
}
`
buff := bytes.NewBufferString(testJSON)
req, err := http.NewRequest("POST", r.server.URL+"/api/v1/kolide/osquery/config/import", buff)
require.Nil(t, err)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.adminToken))
client := &http.Client{}
resp, err := client.Do(req)
require.Nil(t, err)
var impResponse importResponse
err = json.NewDecoder(resp.Body).Decode(&impResponse)
require.Nil(t, err)
assert.Equal(t, 2, impResponse.Response.ImportStatusBySection[kolide.YARASigSection].ImportCount)
assert.Equal(t, 4, impResponse.Response.ImportStatusBySection[kolide.DecoratorsSection].ImportCount)
}
func testImportConfig(t *testing.T, r *testResource) { func testImportConfig(t *testing.T, r *testResource) {
testJSON := ` testJSON := `

View File

@ -110,6 +110,7 @@ var testFunctions = [...]func(*testing.T, *testResource){
testImportConfigMissingExternal, testImportConfigMissingExternal,
testImportConfigWithMissingGlob, testImportConfigWithMissingGlob,
testImportConfigWithGlob, testImportConfigWithGlob,
testImportConfigWithIntAsString,
testAdminUserSetAdmin, testAdminUserSetAdmin,
testNonAdminUserSetAdmin, testNonAdminUserSetAdmin,
testAdminUserSetEnabled, testAdminUserSetEnabled,

View File

@ -197,12 +197,12 @@ func (svc service) importScheduledQueries(uid uint, cfg *kolide.ImportConfig, re
sq := &kolide.ScheduledQuery{ sq := &kolide.ScheduledQuery{
PackID: pack.ID, PackID: pack.ID,
QueryID: query.ID, QueryID: query.ID,
Interval: queryDetails.Interval, Interval: uint(queryDetails.Interval),
Snapshot: queryDetails.Snapshot, Snapshot: queryDetails.Snapshot,
Removed: queryDetails.Removed, Removed: queryDetails.Removed,
Platform: queryDetails.Platform, Platform: queryDetails.Platform,
Version: queryDetails.Version, Version: queryDetails.Version,
Shard: queryDetails.Shard, Shard: configInt2Ptr(queryDetails.Shard),
} }
_, err = svc.ds.NewScheduledQuery(sq) _, err = svc.ds.NewScheduledQuery(sq)
if err != nil { if err != nil {
@ -317,12 +317,12 @@ func (svc service) createQueriesForPack(uid uint, pack *kolide.Pack, details *ko
scheduledQuery := &kolide.ScheduledQuery{ scheduledQuery := &kolide.ScheduledQuery{
PackID: pack.ID, PackID: pack.ID,
QueryID: query.ID, QueryID: query.ID,
Interval: queryDetails.Interval, Interval: uint(queryDetails.Interval),
Platform: queryDetails.Platform, Platform: queryDetails.Platform,
Snapshot: queryDetails.Snapshot, Snapshot: queryDetails.Snapshot,
Removed: queryDetails.Removed, Removed: queryDetails.Removed,
Version: queryDetails.Version, Version: queryDetails.Version,
Shard: queryDetails.Shard, Shard: configInt2Ptr(queryDetails.Shard),
} }
_, err = svc.ds.NewScheduledQuery(scheduledQuery) _, err = svc.ds.NewScheduledQuery(scheduledQuery)
if err != nil { if err != nil {
@ -419,3 +419,11 @@ func (svc service) importOptions(opts kolide.OptionNameToValueMap, resp *kolide.
} }
return nil return nil
} }
func configInt2Ptr(ci *kolide.OsQueryConfigInt) *uint {
if ci == nil {
return nil
}
ui := uint(*ci)
return &ui
}