mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
a228dcb170
* add premium feature critical policies * update documentation * add test for premium-only field * update old change-file * test policies more comprehensively * also test team policies * PATCH returns wrong timestamp, updating test
302 lines
8.6 KiB
Go
302 lines
8.6 KiB
Go
package service
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"regexp"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/pkg/fleethttp"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/pubsub"
|
|
"github.com/fleetdm/fleet/v4/server/sso"
|
|
"github.com/fleetdm/fleet/v4/server/test"
|
|
"github.com/ghodss/yaml"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
)
|
|
|
|
type withDS struct {
|
|
s *suite.Suite
|
|
ds *mysql.Datastore
|
|
}
|
|
|
|
func (ts *withDS) SetupSuite(dbName string) {
|
|
t := ts.s.T()
|
|
ts.ds = mysql.CreateNamedMySQLDS(t, dbName)
|
|
test.AddAllHostsLabel(t, ts.ds)
|
|
|
|
// setup the required fields on AppConfig
|
|
appConf, err := ts.ds.AppConfig(context.Background())
|
|
require.NoError(t, err)
|
|
appConf.OrgInfo.OrgName = "FleetTest"
|
|
appConf.ServerSettings.ServerURL = "https://example.org"
|
|
err = ts.ds.SaveAppConfig(context.Background(), appConf)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func (ts *withDS) TearDownSuite() {
|
|
ts.ds.Close()
|
|
}
|
|
|
|
type withServer struct {
|
|
withDS
|
|
|
|
server *httptest.Server
|
|
users map[string]fleet.User
|
|
token string
|
|
cachedAdminToken string
|
|
}
|
|
|
|
func (ts *withServer) SetupSuite(dbName string) {
|
|
ts.withDS.SetupSuite(dbName)
|
|
|
|
rs := pubsub.NewInmemQueryResults()
|
|
users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds, &TestServerOpts{Rs: rs})
|
|
ts.server = server
|
|
ts.users = users
|
|
ts.token = ts.getTestAdminToken()
|
|
ts.cachedAdminToken = ts.token
|
|
}
|
|
|
|
func (ts *withServer) TearDownSuite() {
|
|
ts.withDS.TearDownSuite()
|
|
}
|
|
|
|
func (ts *withServer) commonTearDownTest(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
u := ts.users["admin1@example.com"]
|
|
filter := fleet.TeamFilter{User: &u}
|
|
hosts, err := ts.ds.ListHosts(ctx, filter, fleet.HostListOptions{})
|
|
require.NoError(t, err)
|
|
for _, host := range hosts {
|
|
require.NoError(t, ts.ds.UpdateHostSoftware(context.Background(), host.ID, nil))
|
|
require.NoError(t, ts.ds.DeleteHost(ctx, host.ID))
|
|
}
|
|
|
|
// recalculate software counts will remove the software entries
|
|
require.NoError(t, ts.ds.SyncHostsSoftware(context.Background(), time.Now()))
|
|
|
|
lbls, err := ts.ds.ListLabels(ctx, fleet.TeamFilter{}, fleet.ListOptions{})
|
|
require.NoError(t, err)
|
|
for _, lbl := range lbls {
|
|
if lbl.LabelType != fleet.LabelTypeBuiltIn {
|
|
err := ts.ds.DeleteLabel(ctx, lbl.Name)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
users, err := ts.ds.ListUsers(ctx, fleet.UserListOptions{})
|
|
require.NoError(t, err)
|
|
for _, u := range users {
|
|
if _, ok := ts.users[u.Email]; !ok {
|
|
err := ts.ds.DeleteUser(ctx, u.ID)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
teams, err := ts.ds.ListTeams(ctx, fleet.TeamFilter{User: &u}, fleet.ListOptions{})
|
|
require.NoError(t, err)
|
|
for _, tm := range teams {
|
|
err := ts.ds.DeleteTeam(ctx, tm.ID)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
globalPolicies, err := ts.ds.ListGlobalPolicies(ctx)
|
|
require.NoError(t, err)
|
|
if len(globalPolicies) > 0 {
|
|
var globalPolicyIDs []uint
|
|
for _, gp := range globalPolicies {
|
|
globalPolicyIDs = append(globalPolicyIDs, gp.ID)
|
|
}
|
|
_, err = ts.ds.DeleteGlobalPolicies(ctx, globalPolicyIDs)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// SyncHostsSoftware performs a cleanup.
|
|
err = ts.ds.SyncHostsSoftware(ctx, time.Now())
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func (ts *withServer) Do(verb, path string, params interface{}, expectedStatusCode int, queryParams ...string) *http.Response {
|
|
t := ts.s.T()
|
|
|
|
j, err := json.Marshal(params)
|
|
require.NoError(t, err)
|
|
|
|
resp := ts.DoRaw(verb, path, j, expectedStatusCode, queryParams...)
|
|
|
|
t.Cleanup(func() {
|
|
resp.Body.Close()
|
|
})
|
|
return resp
|
|
}
|
|
|
|
func (ts *withServer) DoRawWithHeaders(
|
|
verb string, path string, rawBytes []byte, expectedStatusCode int, headers map[string]string, queryParams ...string,
|
|
) *http.Response {
|
|
t := ts.s.T()
|
|
|
|
requestBody := io.NopCloser(bytes.NewBuffer(rawBytes))
|
|
req, err := http.NewRequest(verb, ts.server.URL+path, requestBody)
|
|
require.NoError(t, err)
|
|
for key, val := range headers {
|
|
req.Header.Add(key, val)
|
|
}
|
|
client := fleethttp.NewClient()
|
|
|
|
if len(queryParams)%2 != 0 {
|
|
require.Fail(t, "need even number of params: key value")
|
|
}
|
|
if len(queryParams) > 0 {
|
|
q := req.URL.Query()
|
|
for i := 0; i < len(queryParams); i += 2 {
|
|
q.Add(queryParams[i], queryParams[i+1])
|
|
}
|
|
req.URL.RawQuery = q.Encode()
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, expectedStatusCode, resp.StatusCode)
|
|
|
|
return resp
|
|
}
|
|
|
|
func (ts *withServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int, queryParams ...string) *http.Response {
|
|
return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, map[string]string{
|
|
"Authorization": fmt.Sprintf("Bearer %s", ts.token),
|
|
}, queryParams...)
|
|
}
|
|
|
|
func (ts *withServer) DoRawNoAuth(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response {
|
|
return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, nil)
|
|
}
|
|
|
|
func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStatusCode int, v interface{}, queryParams ...string) {
|
|
resp := ts.Do(verb, path, params, expectedStatusCode, queryParams...)
|
|
err := json.NewDecoder(resp.Body).Decode(v)
|
|
require.NoError(ts.s.T(), err)
|
|
if e, ok := v.(errorer); ok {
|
|
require.NoError(ts.s.T(), e.error())
|
|
}
|
|
}
|
|
|
|
func (ts *withServer) getTestAdminToken() string {
|
|
testUser := testUsers["admin1"]
|
|
|
|
// because the login endpoint is rate-limited, use the cached admin token
|
|
// if available (if for some reason a test needs to logout the admin user,
|
|
// then set cachedAdminToken = "" so that a new token is retrieved).
|
|
if ts.cachedAdminToken == "" {
|
|
ts.cachedAdminToken = ts.getTestToken(testUser.Email, testUser.PlaintextPassword)
|
|
}
|
|
return ts.cachedAdminToken
|
|
}
|
|
|
|
func (ts *withServer) getTestToken(email string, password string) string {
|
|
params := loginRequest{
|
|
Email: email,
|
|
Password: password,
|
|
}
|
|
j, err := json.Marshal(¶ms)
|
|
require.NoError(ts.s.T(), err)
|
|
|
|
requestBody := io.NopCloser(bytes.NewBuffer(j))
|
|
resp, err := http.Post(ts.server.URL+"/api/latest/fleet/login", "application/json", requestBody)
|
|
require.NoError(ts.s.T(), err)
|
|
defer resp.Body.Close()
|
|
assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode)
|
|
|
|
jsn := struct {
|
|
User *fleet.User `json:"user"`
|
|
Token string `json:"token"`
|
|
Err []map[string]string `json:"errors,omitempty"`
|
|
}{}
|
|
err = json.NewDecoder(resp.Body).Decode(&jsn)
|
|
require.NoError(ts.s.T(), err)
|
|
require.Len(ts.s.T(), jsn.Err, 0)
|
|
|
|
return jsn.Token
|
|
}
|
|
|
|
func (ts *withServer) applyConfig(spec []byte) {
|
|
var appConfigSpec interface{}
|
|
err := yaml.Unmarshal(spec, &appConfigSpec)
|
|
require.NoError(ts.s.T(), err)
|
|
|
|
ts.Do("PATCH", "/api/latest/fleet/config", appConfigSpec, http.StatusOK)
|
|
}
|
|
|
|
func (ts *withServer) getConfig() *appConfigResponse {
|
|
var responseBody *appConfigResponse
|
|
ts.DoJSON("GET", "/api/latest/fleet/config", nil, http.StatusOK, &responseBody)
|
|
return responseBody
|
|
}
|
|
|
|
func (ts *withServer) LoginSSOUser(username, password string) (fleet.Auth, string) {
|
|
t := ts.s.T()
|
|
|
|
if _, ok := os.LookupEnv("SAML_IDP_TEST"); !ok {
|
|
t.Skip("SSO tests are disabled")
|
|
}
|
|
|
|
var resIni initiateSSOResponse
|
|
ts.DoJSON("POST", "/api/v1/fleet/sso", map[string]string{}, http.StatusOK, &resIni)
|
|
|
|
jar, err := cookiejar.New(nil)
|
|
require.NoError(t, err)
|
|
|
|
client := fleethttp.NewClient(
|
|
fleethttp.WithFollowRedir(false),
|
|
fleethttp.WithCookieJar(jar),
|
|
)
|
|
|
|
resp, err := client.Get(resIni.URL)
|
|
require.NoError(t, err)
|
|
|
|
// From the redirect Location header we can get the AuthState and the URL to
|
|
// which we submit the credentials
|
|
parsed, err := url.Parse(resp.Header.Get("Location"))
|
|
require.NoError(t, err)
|
|
data := url.Values{
|
|
"username": {username},
|
|
"password": {password},
|
|
"AuthState": {parsed.Query().Get("AuthState")},
|
|
}
|
|
resp, err = client.PostForm(parsed.Scheme+"://"+parsed.Host+parsed.Path, data)
|
|
require.NoError(t, err)
|
|
|
|
// The response is an HTML form, we can extract the base64-encoded response
|
|
// to submit to the Fleet server from here
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
re := regexp.MustCompile(`value="(.*)"`)
|
|
matches := re.FindSubmatch(body)
|
|
require.NotEmptyf(t, matches, "callback HTML doesn't contain a SAMLResponse value, got body: %s", body)
|
|
rawSSOResp := string(matches[1])
|
|
|
|
auth, err := sso.DecodeAuthResponse(rawSSOResp)
|
|
require.NoError(t, err)
|
|
q := url.QueryEscape(rawSSOResp)
|
|
res := ts.DoRawNoAuth("POST", "/api/v1/fleet/sso/callback?SAMLResponse="+q, nil, http.StatusOK)
|
|
|
|
defer res.Body.Close()
|
|
body, err = io.ReadAll(res.Body)
|
|
require.NoError(t, err)
|
|
return auth, string(body)
|
|
}
|