diff --git a/server/kolide/import_config.go b/server/kolide/import_config.go index 34a3ba8e7..7715d0d2a 100644 --- a/server/kolide/import_config.go +++ b/server/kolide/import_config.go @@ -1,9 +1,11 @@ package kolide import ( + "bytes" "context" "errors" "fmt" + "strconv" "strings" ) @@ -98,21 +100,21 @@ const ( // QueryDetails represents the query objects used in the packs and the // schedule section of an osquery configuration. type QueryDetails struct { - Query string `json:"query"` - Interval uint `json:"interval"` + Query string `json:"query"` + Interval OsQueryConfigInt `json:"interval"` // Optional fields - Removed *bool `json:"removed"` - Platform *string `json:"platform"` - Version *string `json:"version"` - Shard *uint `json:"shard"` - Snapshot *bool `json:"snapshot"` + Removed *bool `json:"removed"` + Platform *string `json:"platform"` + Version *string `json:"version"` + Shard *OsQueryConfigInt `json:"shard"` + Snapshot *bool `json:"snapshot"` } // PackDetails represents the "packs" section of an osquery configuration // file. type PackDetails struct { Queries QueryNameToQueryDetailsMap `json:"queries"` - Shard *uint `json:"shard"` + Shard *OsQueryConfigInt `json:"shard"` Version *string `json:"version"` Platform string `json:"platform"` Discovery []string `json:"discovery"` @@ -215,3 +217,18 @@ func (ic *ImportConfig) CollectPacks() (PackNameToPackDetails, error) { } return result, nil } + +// OsQueryConfigInt is provided becase integers in the osquery config file may +// be represented as strings in the json. If we know a particular field is +// supposed to be an Integer, we convert from string to int if we can. +type OsQueryConfigInt uint + +func (c *OsQueryConfigInt) UnmarshalJSON(b []byte) error { + stripped := bytes.Trim(b, `"`) + v, err := strconv.ParseUint(string(stripped), 10, 64) + if err != nil { + return err + } + *c = OsQueryConfigInt(v) + return nil +} diff --git a/server/kolide/import_config_test.go b/server/kolide/import_config_test.go index 4e6bd8d90..3234b9942 100644 --- a/server/kolide/import_config_test.go +++ b/server/kolide/import_config_test.go @@ -1,14 +1,66 @@ package kolide import ( + "bytes" "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestIntervalUnmarshal(t *testing.T) { + scenarios := []struct { + name string + testVal interface{} + errExpected bool + expectedResult OsQueryConfigInt + }{ + {"string to uint", "100", false, 100}, + {"float to uint", float64(123), false, 123}, + {"nil to zero value int", nil, false, 0}, + {"invalid string", "hi there", true, 0}, + } + for _, scenario := range scenarios { + t.Run(fmt.Sprintf(": %s", scenario.name), func(tt *testing.T) { + v, e := unmarshalInteger(scenario.testVal) + if scenario.errExpected { + assert.NotNil(t, e) + } else { + require.Nil(t, e) + assert.Equal(t, scenario.expectedResult, v) + } + }) + } +} + +type importIntTest struct { + Val OsQueryConfigInt `json:"val"` +} + +func TestConfigImportInt(t *testing.T) { + buff := bytes.NewBufferString(`{"val":"23"}`) + var ts importIntTest + err := json.NewDecoder(buff).Decode(&ts) + assert.Nil(t, err) + assert.Equal(t, 23, int(ts.Val)) + + buff = bytes.NewBufferString(`{"val":456}`) + err = json.NewDecoder(buff).Decode(&ts) + assert.Nil(t, err) + assert.Equal(t, 456, int(ts.Val)) + + buff = bytes.NewBufferString(`{"val":"hi 456"}`) + err = json.NewDecoder(buff).Decode(&ts) + assert.NotNil(t, err) + +} + func TestPackNameMapUnmarshal(t *testing.T) { + s2p := func(s string) *string { return &s } + u2p := func(ui uint) *OsQueryConfigInt { ci := OsQueryConfigInt(ui); return &ci } + pnm := PackNameMap{ "path": "/this/is/a/path", "details": PackDetails{ @@ -17,8 +69,8 @@ func TestPackNameMapUnmarshal(t *testing.T) { Query: "select from foo", Interval: 100, Removed: new(bool), - Platform: strptr("linux"), - Shard: new(uint), + Platform: s2p("linux"), + Shard: new(OsQueryConfigInt), Snapshot: new(bool), }, }, @@ -41,13 +93,13 @@ func TestPackNameMapUnmarshal(t *testing.T) { Query: "select from foo", Interval: 100, Removed: new(bool), - Platform: strptr("linux"), - Shard: new(uint), + Platform: s2p("linux"), + Shard: new(OsQueryConfigInt), Snapshot: new(bool), }, }, - Shard: uintptr(float64(10)), - Version: strptr("1.0"), + Shard: u2p(10), + Version: s2p("1.0"), Platform: "linux", Discovery: []string{ "select from something", @@ -59,13 +111,13 @@ func TestPackNameMapUnmarshal(t *testing.T) { Query: "select from bar", Interval: 100, Removed: new(bool), - Platform: strptr("linux"), - Shard: new(uint), + Platform: s2p("linux"), + Shard: new(OsQueryConfigInt), Snapshot: new(bool), }, }, - Shard: uintptr(float64(10)), - Version: strptr("1.0"), + Shard: u2p(10), + Version: s2p("1.0"), Platform: "linux", }, } diff --git a/server/kolide/import_config_unmarshaler.go b/server/kolide/import_config_unmarshaler.go index c0c9488e7..53898da00 100644 --- a/server/kolide/import_config_unmarshaler.go +++ b/server/kolide/import_config_unmarshaler.go @@ -3,8 +3,11 @@ package kolide import ( "encoding/json" "errors" + "strconv" ) +var wrongTypeError = errors.New("argument missing or unexpected type") + // UnmarshalJSON custom unmarshaling for PackNameMap will determine whether // the pack section of an osquery config file refers to a file path, or // pack details. Pack details are unmarshalled into into PackDetails structure @@ -20,7 +23,11 @@ func (pnm PackNameMap) UnmarshalJSON(b []byte) error { case string: pnm[key] = t case map[string]interface{}: - pnm[key] = unmarshalPackDetails(t) + val, err := unmarshalPackDetails(t) + if err != nil { + return err + } + pnm[key] = val default: return errors.New("can't unmarshal json") } @@ -28,75 +35,176 @@ func (pnm PackNameMap) UnmarshalJSON(b []byte) error { return nil } -func strptr(v interface{}) *string { +func strptr(v interface{}) (*string, error) { if v == nil { - return nil + return nil, nil } - s := new(string) - *s = v.(string) - return s + s, ok := v.(string) + if !ok { + return nil, wrongTypeError + } + return &s, nil } -func boolptr(v interface{}) *bool { +func boolptr(v interface{}) (*bool, error) { if v == nil { - return nil + return nil, nil } - b := new(bool) - *b = v.(bool) - return b + b, ok := v.(bool) + if !ok { + return nil, wrongTypeError + } + return &b, nil } -func uintptr(v interface{}) *uint { +// We expect a float64 here because of the way JSON represents numbers +func uintptr(v interface{}) (*OsQueryConfigInt, error) { if v == nil { - return nil + return nil, nil } - i := new(uint) - *i = uint(v.(float64)) - return i + i, err := unmarshalInteger(v) + if err != nil { + return nil, err + } + return &i, nil } -func unmarshalPackDetails(v map[string]interface{}) PackDetails { - return PackDetails{ - Queries: unmarshalQueryDetails(v["queries"]), - Shard: uintptr(v["shard"]), - Version: strptr(v["version"]), - Platform: v["platform"].(string), - Discovery: unmarshalDiscovery(v["discovery"]), +// Use this when we expext a string value, in this case nil is an error +func toString(v interface{}) (string, error) { + if s, ok := v.(string); ok { + return s, nil } + return "", wrongTypeError } -func unmarshalDiscovery(val interface{}) []string { +func unmarshalPackDetails(v map[string]interface{}) (PackDetails, error) { + var result PackDetails + queries, err := unmarshalQueryDetails(v["queries"]) + if err != nil { + return result, err + } + discovery, err := unmarshalDiscovery(v["discovery"]) + if err != nil { + return result, err + } + platform, err := toString(v["platform"]) + if err != nil { + return result, err + } + shard, err := uintptr(v["shard"]) + if err != nil { + return result, err + } + version, err := strptr(v["version"]) + if err != nil { + return result, err + } + + result = PackDetails{ + Queries: queries, + Shard: shard, + Version: version, + Platform: platform, + Discovery: discovery, + } + return result, nil +} + +func unmarshalDiscovery(val interface{}) ([]string, error) { var result []string if val == nil { - return result + return result, nil + } + v, ok := val.([]interface{}) + if !ok { + return result, wrongTypeError } - v := val.([]interface{}) for _, val := range v { - result = append(result, val.(string)) + query, err := toString(val) + if err != nil { + return result, err + } + result = append(result, query) } - return result + return result, nil } -func unmarshalQueryDetails(v interface{}) QueryNameToQueryDetailsMap { +func unmarshalQueryDetails(v interface{}) (QueryNameToQueryDetailsMap, error) { + var err error result := make(QueryNameToQueryDetailsMap) if v == nil { - return result + return result, nil } for qn, details := range v.(map[string]interface{}) { - result[qn] = unmarshalQueryDetail(details) + result[qn], err = unmarshalQueryDetail(details) + if err != nil { + return nil, err + } } - return result + return result, nil } -func unmarshalQueryDetail(val interface{}) QueryDetails { - v := val.(map[string]interface{}) - return QueryDetails{ - Query: v["query"].(string), - Interval: uint(v["interval"].(float64)), - Removed: boolptr(v["removed"]), - Platform: strptr(v["platform"]), - Version: strptr(v["version"]), - Shard: uintptr(v["shard"]), - Snapshot: boolptr(v["snapshot"]), +func unmarshalQueryDetail(val interface{}) (QueryDetails, error) { + var result QueryDetails + v, ok := val.(map[string]interface{}) + if !ok { + return result, errors.New("argument was missing or the wrong type") + } + interval, err := unmarshalInteger(v["interval"]) + if err != nil { + return result, err + } + query, err := toString(v["query"]) + if err != nil { + return result, err + } + removed, err := boolptr(v["removed"]) + if err != nil { + return result, err + } + platform, err := strptr(v["platform"]) + if err != nil { + return result, err + } + version, err := strptr(v["version"]) + if err != nil { + return result, err + } + shard, err := uintptr(v["shard"]) + if err != nil { + return result, err + } + snapshot, err := boolptr(v["snapshot"]) + if err != nil { + return result, nil + } + result = QueryDetails{ + Query: query, + Interval: OsQueryConfigInt(interval), + Removed: removed, + Platform: platform, + Version: version, + Shard: shard, + Snapshot: snapshot, + } + return result, nil +} + +// It is valid for the interval can be a string that is convertable to an int, +// or an float64. The float64 is how all numbers in JSON are represented, so +// we need to convert to uint +func unmarshalInteger(val interface{}) (OsQueryConfigInt, error) { + // if interval is nil return zero value + if val == nil { + return OsQueryConfigInt(0), nil + } + switch v := val.(type) { + case string: + i, err := strconv.ParseUint(v, 10, 64) + return OsQueryConfigInt(i), err + case float64: + return OsQueryConfigInt(v), nil + default: + return OsQueryConfigInt(0), wrongTypeError } } diff --git a/server/service/endpoint_import_config_test.go b/server/service/endpoint_import_config_test.go index 2e41d42f9..4a779e9ad 100644 --- a/server/service/endpoint_import_config_test.go +++ b/server/service/endpoint_import_config_test.go @@ -60,6 +60,30 @@ func testImportConfigWithMissingGlob(t *testing.T, r *testResource) { } +func testImportConfigWithIntAsString(t *testing.T, r *testResource) { + + testJSON := ` + { + "config": "{\"options\":{\"host_identifier\":\"hostname\",\"schedule_splay_percent\":10},\"schedule\":{\"macosx_kextstat\":{\"query\":\"SELECT * FROM kernel_extensions;\",\"interval\":\"10\"},\"foobar\":{\"query\":\"SELECT foo, bar, pid FROM foobar_table;\",\"interval\":600}},\"packs\":{\"external_pack\":\"/path/to/external_pack.conf\",\"internal_pack\":{\"discovery\":[\"select pid from processes where name = 'foobar';\",\"select count(*) from users where username like 'www%';\"],\"platform\":\"linux\",\"version\":\"1.5.2\",\"queries\":{\"active_directory\":{\"query\":\"select * from ad_config;\",\"interval\":\"1200\",\"description\":\"Check each user's active directory cached settings.\"}}}},\"decorators\":{\"load\":[\"SELECT version FROM osquery_info\",\"SELECT uuid AS host_uuid FROM system_info\"],\"always\":[\"SELECT user AS username FROM logged_in_users WHERE user <> '' ORDER BY time LIMIT 1;\"],\"interval\":{\"3600\":[\"SELECT total_seconds AS uptime FROM uptime;\"]}},\"yara\":{\"signatures\":{\"sig_group_1\":[\"/Users/wxs/sigs/foo.sig\",\"/Users/wxs/sigs/bar.sig\"],\"sig_group_2\":[\"/Users/wxs/sigs/baz.sig\"]},\"file_paths\":{\"system_binaries\":[\"sig_group_1\"],\"tmp\":[\"sig_group_1\",\"sig_group_2\"]}},\"file_paths\":{\"system_binaries\":[\"/usr/bin/%\",\"/usr/sbin/%\"],\"tmp\":[\"/Users/%/tmp/%%\",\"/tmp/%\"]}}", + "external_pack_configs": { + "external_pack": "{\"discovery\":[\"select pid from processes where name = 'baz';\"],\"platform\":\"linux\",\"version\":\"1.5.2\",\"queries\":{\"something\":{\"query\":\"select * from something;\",\"interval\":1200,\"description\":\"Check something.\"}}}" + } + } + ` + buff := bytes.NewBufferString(testJSON) + req, err := http.NewRequest("POST", r.server.URL+"/api/v1/kolide/osquery/config/import", buff) + require.Nil(t, err) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", r.adminToken)) + client := &http.Client{} + resp, err := client.Do(req) + require.Nil(t, err) + var impResponse importResponse + err = json.NewDecoder(resp.Body).Decode(&impResponse) + require.Nil(t, err) + assert.Equal(t, 2, impResponse.Response.ImportStatusBySection[kolide.YARASigSection].ImportCount) + assert.Equal(t, 4, impResponse.Response.ImportStatusBySection[kolide.DecoratorsSection].ImportCount) +} + func testImportConfig(t *testing.T, r *testResource) { testJSON := ` diff --git a/server/service/endpoint_test.go b/server/service/endpoint_test.go index 9eca2a882..da592bffd 100644 --- a/server/service/endpoint_test.go +++ b/server/service/endpoint_test.go @@ -110,6 +110,7 @@ var testFunctions = [...]func(*testing.T, *testResource){ testImportConfigMissingExternal, testImportConfigWithMissingGlob, testImportConfigWithGlob, + testImportConfigWithIntAsString, testAdminUserSetAdmin, testNonAdminUserSetAdmin, testAdminUserSetEnabled, diff --git a/server/service/service_import_config.go b/server/service/service_import_config.go index a963160bc..4f335c596 100644 --- a/server/service/service_import_config.go +++ b/server/service/service_import_config.go @@ -197,12 +197,12 @@ func (svc service) importScheduledQueries(uid uint, cfg *kolide.ImportConfig, re sq := &kolide.ScheduledQuery{ PackID: pack.ID, QueryID: query.ID, - Interval: queryDetails.Interval, + Interval: uint(queryDetails.Interval), Snapshot: queryDetails.Snapshot, Removed: queryDetails.Removed, Platform: queryDetails.Platform, Version: queryDetails.Version, - Shard: queryDetails.Shard, + Shard: configInt2Ptr(queryDetails.Shard), } _, err = svc.ds.NewScheduledQuery(sq) if err != nil { @@ -317,12 +317,12 @@ func (svc service) createQueriesForPack(uid uint, pack *kolide.Pack, details *ko scheduledQuery := &kolide.ScheduledQuery{ PackID: pack.ID, QueryID: query.ID, - Interval: queryDetails.Interval, + Interval: uint(queryDetails.Interval), Platform: queryDetails.Platform, Snapshot: queryDetails.Snapshot, Removed: queryDetails.Removed, Version: queryDetails.Version, - Shard: queryDetails.Shard, + Shard: configInt2Ptr(queryDetails.Shard), } _, err = svc.ds.NewScheduledQuery(scheduledQuery) if err != nil { @@ -419,3 +419,11 @@ func (svc service) importOptions(opts kolide.OptionNameToValueMap, resp *kolide. } return nil } + +func configInt2Ptr(ci *kolide.OsQueryConfigInt) *uint { + if ci == nil { + return nil + } + ui := uint(*ci) + return &ui +}