fleet/server/webhooks/vulnerabilities_test.go
2022-12-05 16:50:49 -06:00

206 lines
5.7 KiB
Go

package webhooks
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
kitlog "github.com/go-kit/kit/log"
"github.com/stretchr/testify/require"
"github.com/tj/assert"
)
func TestTriggerVulnerabilitiesWebhook(t *testing.T) {
ctx := context.Background()
ds := new(mock.Store)
logger := kitlog.NewNopLogger()
mapper := Mapper{}
appCfg := &fleet.AppConfig{
WebhookSettings: fleet.WebhookSettings{
VulnerabilitiesWebhook: fleet.VulnerabilitiesWebhookSettings{
Enable: true,
HostBatchSize: 2,
},
},
ServerSettings: fleet.ServerSettings{
ServerURL: "https://fleet.example.com",
},
}
recentVulns := []fleet.SoftwareVulnerability{
{SoftwareID: 1, CVE: "CVE-2012-1234"},
{SoftwareID: 2, CVE: "CVE-2012-1234"},
}
t.Run("disabled", func(t *testing.T) {
appCfg := *appCfg
appCfg.WebhookSettings.VulnerabilitiesWebhook.Enable = false
args := VulnArgs{
Vulnerablities: recentVulns,
Meta: nil,
AppConfig: &appCfg,
Time: time.Now(),
}
err := TriggerVulnerabilitiesWebhook(ctx, ds, logger, args, &mapper)
require.NoError(t, err)
})
t.Run("invalid server url", func(t *testing.T) {
appCfg := *appCfg
appCfg.ServerSettings.ServerURL = ":nope:"
args := VulnArgs{
Vulnerablities: recentVulns,
Meta: nil,
AppConfig: &appCfg,
Time: time.Now(),
}
err := TriggerVulnerabilitiesWebhook(ctx, ds, logger, args, &mapper)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid server")
})
t.Run("empty recent vulns", func(t *testing.T) {
args := VulnArgs{
Vulnerablities: nil,
Meta: nil,
AppConfig: appCfg,
Time: time.Now(),
}
err := TriggerVulnerabilitiesWebhook(ctx, ds, logger, args, &mapper)
require.NoError(t, err)
})
t.Run("trigger requests", func(t *testing.T) {
now := time.Now()
hosts := []*fleet.HostShort{
{ID: 1, Hostname: "h1", DisplayName: "d1"},
{ID: 2, Hostname: "h2", DisplayName: "d2"},
{ID: 3, Hostname: "h3", DisplayName: "d3"},
{ID: 4, Hostname: "h4", DisplayName: "d4"},
}
jsonH1 := fmt.Sprintf(`{"id":1,"hostname":"h1","display_name":"d1","url":"%s/hosts/1"}`, appCfg.ServerSettings.ServerURL)
jsonH2 := fmt.Sprintf(`{"id":2,"hostname":"h2","display_name":"d2","url":"%s/hosts/2"}`, appCfg.ServerSettings.ServerURL)
jsonH3 := fmt.Sprintf(`{"id":3,"hostname":"h3","display_name":"d3","url":"%s/hosts/3"}`, appCfg.ServerSettings.ServerURL)
jsonH4 := fmt.Sprintf(`{"id":4,"hostname":"h4","display_name":"d4","url":"%s/hosts/4"}`, appCfg.ServerSettings.ServerURL)
cves := []string{
"CVE-2012-1234",
"CVE-2012-4567",
}
jsonCVE1 := fmt.Sprintf(`{"timestamp":"%s","vulnerability":{"cve":%q,"details_link":"https://nvd.nist.gov/vuln/detail/%[2]s","hosts_affected":`,
now.Format(time.RFC3339Nano), cves[0])
jsonCVE2 := fmt.Sprintf(`{"timestamp":"%s","vulnerability":{"cve":%q,"details_link":"https://nvd.nist.gov/vuln/detail/%[2]s","hosts_affected":`,
now.Format(time.RFC3339Nano), cves[1])
cases := []struct {
name string
vulns []fleet.SoftwareVulnerability
meta map[string]fleet.CVEMeta
hosts []*fleet.HostShort
want string
}{
{
"1 vuln, 1 host",
[]fleet.SoftwareVulnerability{{CVE: cves[0], SoftwareID: 1}},
nil,
hosts[:1],
fmt.Sprintf("%s[%s]}}", jsonCVE1, jsonH1),
},
{
"1 vuln in multiple software, 1 host",
[]fleet.SoftwareVulnerability{
{CVE: cves[0], SoftwareID: 1},
{CVE: cves[0], SoftwareID: 1},
{CVE: cves[0], SoftwareID: 2},
},
nil,
hosts[:1],
fmt.Sprintf("%s[%s]}}", jsonCVE1, jsonH1),
},
{
"1 vuln, 2 hosts",
[]fleet.SoftwareVulnerability{
{CVE: cves[0], SoftwareID: 1},
},
nil,
hosts[:2],
fmt.Sprintf("%s[%s,%s]}}", jsonCVE1, jsonH1, jsonH2),
},
{
"1 vuln, 3 hosts",
[]fleet.SoftwareVulnerability{
{CVE: cves[0], SoftwareID: 1},
},
nil,
hosts[:3],
fmt.Sprintf("%s[%s,%s]}}\n%s[%s]}}", jsonCVE1, jsonH1, jsonH2, jsonCVE1, jsonH3), // 2 requests, batch of 2 max
},
{
"1 vuln, 4 hosts",
[]fleet.SoftwareVulnerability{
{CVE: cves[0], SoftwareID: 1},
},
nil,
hosts[:4],
fmt.Sprintf("%s[%s,%s]}}\n%s[%s,%s]}}", jsonCVE1, jsonH1, jsonH2, jsonCVE1, jsonH3, jsonH4), // 2 requests, batch of 2 max
},
{
"2 vulns, 1 host each",
[]fleet.SoftwareVulnerability{
{CVE: cves[0], SoftwareID: 1},
{CVE: cves[1], SoftwareID: 2},
},
nil,
hosts[:1],
fmt.Sprintf("%s[%s]}}\n%s[%s]}}", jsonCVE1, jsonH1, jsonCVE2, jsonH1),
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var requests []string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
requests = append(requests, string(b))
_, err = w.Write(nil)
assert.NoError(t, err)
}))
defer srv.Close()
ds.HostsBySoftwareIDsFunc = func(ctx context.Context, softwareIDs []uint) ([]*fleet.HostShort, error) {
return c.hosts, nil
}
appCfg := *appCfg
appCfg.WebhookSettings.VulnerabilitiesWebhook.DestinationURL = srv.URL
args := VulnArgs{
Vulnerablities: c.vulns,
Meta: c.meta,
AppConfig: &appCfg,
Time: now,
}
err := TriggerVulnerabilitiesWebhook(ctx, ds, logger, args, &mapper)
require.NoError(t, err)
assert.True(t, ds.HostsBySoftwareIDsFuncInvoked)
ds.HostsBySoftwareIDsFuncInvoked = false
want := strings.Split(c.want, "\n")
assert.ElementsMatch(t, want, requests)
})
}
})
}