mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
196 lines
5.5 KiB
Go
196 lines
5.5 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/datastore/cached_mysql"
|
|
"github.com/fleetdm/fleet/v4/server/datastore/mysql"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
nanodepClient "github.com/fleetdm/fleet/v4/server/mdm/nanodep/client"
|
|
"github.com/fleetdm/fleet/v4/server/mock"
|
|
"github.com/fleetdm/fleet/v4/server/service"
|
|
"github.com/fleetdm/fleet/v4/server/test"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
"github.com/urfave/cli/v2"
|
|
)
|
|
|
|
type withDS struct {
|
|
suite *suite.Suite
|
|
ds *mysql.Datastore
|
|
}
|
|
|
|
func (ts *withDS) SetupSuite(dbName string) {
|
|
t := ts.suite.T()
|
|
ts.ds = mysql.CreateNamedMySQLDS(t, dbName)
|
|
test.AddAllHostsLabel(t, ts.ds)
|
|
|
|
// Set up the required fields on AppConfig
|
|
appConf, err := ts.ds.AppConfig(context.Background())
|
|
require.NoError(t, err)
|
|
appConf.OrgInfo.OrgName = "FleetTest"
|
|
appConf.ServerSettings.ServerURL = "https://example.org"
|
|
err = ts.ds.SaveAppConfig(context.Background(), appConf)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func (ts *withDS) TearDownSuite() {
|
|
_ = ts.ds.Close()
|
|
}
|
|
|
|
type withServer struct {
|
|
withDS
|
|
|
|
server *httptest.Server
|
|
users map[string]fleet.User
|
|
}
|
|
|
|
// loginRequest is the JSON body that getTestToken posts to the login endpoint.
type loginRequest struct {
	// Email is serialized under the "email" key.
	Email string `json:"email"`

	// Password is serialized under the "password" key.
	Password string `json:"password"`
}
|
|
|
|
func (ts *withServer) getTestToken(email string, password string) string {
|
|
params := loginRequest{
|
|
Email: email,
|
|
Password: password,
|
|
}
|
|
j, err := json.Marshal(¶ms)
|
|
require.NoError(ts.suite.T(), err)
|
|
|
|
requestBody := io.NopCloser(bytes.NewBuffer(j))
|
|
resp, err := http.Post(ts.server.URL+"/api/latest/fleet/login", "application/json", requestBody)
|
|
require.NoError(ts.suite.T(), err)
|
|
defer func() { _ = resp.Body.Close() }()
|
|
assert.Equal(ts.suite.T(), http.StatusOK, resp.StatusCode)
|
|
|
|
jsn := struct {
|
|
User *fleet.User `json:"user"`
|
|
Token string `json:"token"`
|
|
Err []map[string]string `json:"errors,omitempty"`
|
|
}{}
|
|
err = json.NewDecoder(resp.Body).Decode(&jsn)
|
|
require.NoError(ts.suite.T(), err)
|
|
require.Len(ts.suite.T(), jsn.Err, 0)
|
|
|
|
return jsn.Token
|
|
}
|
|
|
|
var testBMToken = &nanodepClient.OAuth1Tokens{
|
|
ConsumerKey: "test_consumer",
|
|
ConsumerSecret: "test_secret",
|
|
AccessToken: "test_access_token",
|
|
AccessSecret: "test_access_secret",
|
|
AccessTokenExpiry: time.Date(2999, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
}
|
|
|
|
// runServerWithMockedDS runs the fleet server with several mocked DS methods.
|
|
//
|
|
// NOTE: Assumes the current session is always from the admin user (see ds.SessionByKeyFunc below).
|
|
func runServerWithMockedDS(t *testing.T, opts ...*service.TestServerOpts) (*httptest.Server, *mock.Store) {
|
|
ds := new(mock.Store)
|
|
var users []*fleet.User
|
|
var admin *fleet.User
|
|
ds.NewUserFunc = func(ctx context.Context, user *fleet.User) (*fleet.User, error) {
|
|
if user.GlobalRole != nil && *user.GlobalRole == fleet.RoleAdmin {
|
|
admin = user
|
|
}
|
|
users = append(users, user)
|
|
return user, nil
|
|
}
|
|
ds.SessionByKeyFunc = func(ctx context.Context, key string) (*fleet.Session, error) {
|
|
return &fleet.Session{
|
|
CreateTimestamp: fleet.CreateTimestamp{CreatedAt: time.Now()},
|
|
ID: 1,
|
|
AccessedAt: time.Now(),
|
|
UserID: admin.ID,
|
|
Key: key,
|
|
}, nil
|
|
}
|
|
ds.MarkSessionAccessedFunc = func(ctx context.Context, session *fleet.Session) error {
|
|
return nil
|
|
}
|
|
ds.UserByIDFunc = func(ctx context.Context, id uint) (*fleet.User, error) {
|
|
return admin, nil
|
|
}
|
|
ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
|
|
return users, nil
|
|
}
|
|
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
|
return &fleet.AppConfig{}, nil
|
|
}
|
|
|
|
var cachedDS fleet.Datastore
|
|
if len(opts) > 0 && opts[0].NoCacheDatastore {
|
|
cachedDS = ds
|
|
} else {
|
|
cachedDS = cached_mysql.New(ds)
|
|
}
|
|
_, server := service.RunServerForTestsWithDS(t, cachedDS, opts...)
|
|
os.Setenv("FLEET_SERVER_ADDRESS", server.URL)
|
|
|
|
return server, ds
|
|
}
|
|
|
|
func runAppForTest(t *testing.T, args []string) string {
|
|
w, err := runAppNoChecks(args)
|
|
require.NoError(t, err)
|
|
return w.String()
|
|
}
|
|
|
|
func runAppCheckErr(t *testing.T, args []string, errorMsg string) string {
|
|
w, err := runAppNoChecks(args)
|
|
require.Error(t, err)
|
|
require.Equal(t, errorMsg, err.Error())
|
|
return w.String()
|
|
}
|
|
|
|
func runAppNoChecks(args []string) (*bytes.Buffer, error) {
|
|
// first arg must be the binary name. Allow tests to omit it.
|
|
args = append([]string{""}, args...)
|
|
|
|
w := new(bytes.Buffer)
|
|
app := createApp(nil, w, os.Stderr, noopExitErrHandler)
|
|
err := app.Run(args)
|
|
return w, err
|
|
}
|
|
|
|
func runWithErrWriter(args []string, errWriter io.Writer) (*bytes.Buffer, error) {
|
|
args = append([]string{""}, args...)
|
|
|
|
w := new(bytes.Buffer)
|
|
app := createApp(nil, w, errWriter, noopExitErrHandler)
|
|
err := app.Run(args)
|
|
return w, err
|
|
}
|
|
|
|
// noopExitErrHandler ignores both of its arguments. The test helpers pass it
// to createApp as the app's exit-error handler — presumably so that a command
// error is returned from app.Run rather than terminating the test binary
// (confirm against createApp).
func noopExitErrHandler(_ *cli.Context, _ error) {}
|
|
|
|
// serveMDMBootstrapPackage reads the file at pkgPath and starts an httptest
// server that answers every request with those bytes. Each response is sent
// as an application/octet-stream attachment named pkgName, with an explicit
// Content-Length. The server is closed automatically when the test finishes.
// It returns the server and the package size in bytes.
func serveMDMBootstrapPackage(t *testing.T, pkgPath, pkgName string) (*httptest.Server, int) {
	t.Helper()

	pkgBytes, err := os.ReadFile(pkgPath)
	if err != nil {
		t.Fatalf("reading bootstrap package %q: %v", pkgPath, err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Length", strconv.Itoa(len(pkgBytes)))
		w.Header().Set("Content-Type", "application/octet-stream")
		w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%s"`, pkgName))

		// This handler runs on a server goroutine, where t.FailNow (and hence
		// require.*) must not be called, so problems are reported with t.Errorf.
		// The error and the byte count are checked independently: a short write
		// must be reported even when Write returns no error.
		n, err := w.Write(pkgBytes)
		if err != nil {
			t.Errorf("writing bootstrap package response: %v", err)
			return
		}
		if n != len(pkgBytes) {
			t.Errorf("short write of bootstrap package: wrote %d of %d bytes", n, len(pkgBytes))
		}
	}))
	t.Cleanup(srv.Close)

	return srv, len(pkgBytes)
}
|