mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Migrate remaining user-authenticated endpoints (#3796)
This commit is contained in:
parent
3204ff8e0c
commit
8b8cebb6fe
@ -2,6 +2,7 @@ package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
@ -16,6 +17,9 @@ func (d *Datastore) SessionByKey(ctx context.Context, key string) (*fleet.Sessio
|
||||
session := &fleet.Session{}
|
||||
err := sqlx.GetContext(ctx, d.reader, session, sqlStatement, key)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ctxerr.Wrap(ctx, notFound("Session").WithName("<key redacted>"))
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "selecting sessions")
|
||||
}
|
||||
|
||||
@ -31,6 +35,9 @@ func (d *Datastore) SessionByID(ctx context.Context, id uint) (*fleet.Session, e
|
||||
session := &fleet.Session{}
|
||||
err := sqlx.GetContext(ctx, d.reader, session, sqlStatement, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ctxerr.Wrap(ctx, notFound("Session").WithID(id))
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "selecting session by id")
|
||||
}
|
||||
|
||||
|
434
server/service/appconfig.go
Normal file
434
server/service/appconfig.go
Normal file
@ -0,0 +1,434 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/kolide/kit/version"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get AppConfig
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type appConfigResponse struct {
|
||||
fleet.AppConfig
|
||||
|
||||
UpdateInterval *fleet.UpdateIntervalConfig `json:"update_interval"`
|
||||
Vulnerabilities *fleet.VulnerabilitiesConfig `json:"vulnerabilities"`
|
||||
|
||||
// License is loaded from the service
|
||||
License *fleet.LicenseInfo `json:"license,omitempty"`
|
||||
// Logging is loaded on the fly rather than from the database.
|
||||
Logging *fleet.Logging `json:"logging,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r appConfigResponse) error() error { return r.Err }
|
||||
|
||||
func getAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not fetch user")
|
||||
}
|
||||
config, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
license, err := svc.License(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loggingConfig, err := svc.LoggingConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updateIntervalConfig, err := svc.UpdateIntervalConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vulnConfig, err := svc.VulnerabilitiesConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var smtpSettings fleet.SMTPSettings
|
||||
var ssoSettings fleet.SSOSettings
|
||||
var hostExpirySettings fleet.HostExpirySettings
|
||||
var agentOptions *json.RawMessage
|
||||
// only admin can see smtp, sso, and host expiry settings
|
||||
if vc.User.GlobalRole != nil && *vc.User.GlobalRole == fleet.RoleAdmin {
|
||||
smtpSettings = config.SMTPSettings
|
||||
if smtpSettings.SMTPPassword != "" {
|
||||
smtpSettings.SMTPPassword = "********"
|
||||
}
|
||||
ssoSettings = config.SSOSettings
|
||||
hostExpirySettings = config.HostExpirySettings
|
||||
agentOptions = config.AgentOptions
|
||||
}
|
||||
hostSettings := config.HostSettings
|
||||
response := appConfigResponse{
|
||||
AppConfig: fleet.AppConfig{
|
||||
OrgInfo: config.OrgInfo,
|
||||
ServerSettings: config.ServerSettings,
|
||||
HostSettings: hostSettings,
|
||||
VulnerabilitySettings: config.VulnerabilitySettings,
|
||||
|
||||
SMTPSettings: smtpSettings,
|
||||
SSOSettings: ssoSettings,
|
||||
HostExpirySettings: hostExpirySettings,
|
||||
AgentOptions: agentOptions,
|
||||
|
||||
WebhookSettings: config.WebhookSettings,
|
||||
},
|
||||
UpdateInterval: updateIntervalConfig,
|
||||
Vulnerabilities: vulnConfig,
|
||||
License: license,
|
||||
Logging: loggingConfig,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (svc *Service) AppConfig(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return svc.ds.AppConfig(ctx)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Modify AppConfig
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type modifyAppConfigRequest struct {
|
||||
json.RawMessage
|
||||
}
|
||||
|
||||
func modifyAppConfigEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*modifyAppConfigRequest)
|
||||
config, err := svc.ModifyAppConfig(ctx, req.RawMessage)
|
||||
if err != nil {
|
||||
return appConfigResponse{Err: err}, nil
|
||||
}
|
||||
license, err := svc.License(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loggingConfig, err := svc.LoggingConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := appConfigResponse{
|
||||
AppConfig: *config,
|
||||
License: license,
|
||||
Logging: loggingConfig,
|
||||
}
|
||||
|
||||
if response.SMTPSettings.SMTPPassword != "" {
|
||||
response.SMTPSettings.SMTPPassword = "********"
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte) (*fleet.AppConfig, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
appConfig, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(mna): this ports the validations from the old validationMiddleware
|
||||
// correctly, but this could be optimized so that we don't unmarshal the
|
||||
// incoming bytes twice.
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
var newAppConfig fleet.AppConfig
|
||||
if err := json.Unmarshal(p, &newAppConfig); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err)
|
||||
}
|
||||
validateSSOSettings(newAppConfig, appConfig, invalid)
|
||||
if invalid.HasErrors() {
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
|
||||
// We apply the config that is incoming to the old one
|
||||
decoder := json.NewDecoder(bytes.NewReader(p))
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(&appConfig); err != nil {
|
||||
return nil, &badRequestError{message: err.Error()}
|
||||
}
|
||||
|
||||
if appConfig.SMTPSettings.SMTPEnabled || appConfig.SMTPSettings.SMTPConfigured {
|
||||
if err = svc.sendTestEmail(ctx, appConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appConfig.SMTPSettings.SMTPConfigured = true
|
||||
} else if appConfig.SMTPSettings.SMTPEnabled {
|
||||
appConfig.SMTPSettings.SMTPConfigured = false
|
||||
}
|
||||
|
||||
if err := svc.ds.SaveAppConfig(ctx, appConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appConfig, nil
|
||||
}
|
||||
|
||||
func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError) {
|
||||
if p.SSOSettings.EnableSSO {
|
||||
if p.SSOSettings.Metadata == "" && p.SSOSettings.MetadataURL == "" {
|
||||
if existing.SSOSettings.Metadata == "" && existing.SSOSettings.MetadataURL == "" {
|
||||
invalid.Append("metadata", "either metadata or metadata_url must be defined")
|
||||
}
|
||||
}
|
||||
if p.SSOSettings.Metadata != "" && p.SSOSettings.MetadataURL != "" {
|
||||
invalid.Append("metadata", "both metadata and metadata_url are defined, only one is allowed")
|
||||
}
|
||||
if p.SSOSettings.EntityID == "" {
|
||||
if existing.SSOSettings.EntityID == "" {
|
||||
invalid.Append("entity_id", "required")
|
||||
}
|
||||
} else {
|
||||
if len(p.SSOSettings.EntityID) < 5 {
|
||||
invalid.Append("entity_id", "must be 5 or more characters")
|
||||
}
|
||||
}
|
||||
if p.SSOSettings.IDPName == "" {
|
||||
if existing.SSOSettings.IDPName == "" {
|
||||
invalid.Append("idp_name", "required")
|
||||
}
|
||||
} else {
|
||||
if len(p.SSOSettings.IDPName) < 4 {
|
||||
invalid.Append("idp_name", "must be 4 or more characters")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Apply enroll secret spec
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type applyEnrollSecretSpecRequest struct {
|
||||
Spec *fleet.EnrollSecretSpec `json:"spec"`
|
||||
}
|
||||
|
||||
type applyEnrollSecretSpecResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r applyEnrollSecretSpecResponse) error() error { return r.Err }
|
||||
|
||||
func applyEnrollSecretSpecEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*applyEnrollSecretSpecRequest)
|
||||
err := svc.ApplyEnrollSecretSpec(ctx, req.Spec)
|
||||
if err != nil {
|
||||
return applyEnrollSecretSpecResponse{Err: err}, nil
|
||||
}
|
||||
return applyEnrollSecretSpecResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) ApplyEnrollSecretSpec(ctx context.Context, spec *fleet.EnrollSecretSpec) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.EnrollSecret{}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, s := range spec.Secrets {
|
||||
if s.Secret == "" {
|
||||
return ctxerr.New(ctx, "enroll secret must not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
return svc.ds.ApplyEnrollSecrets(ctx, nil, spec.Secrets)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get enroll secret spec
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getEnrollSecretSpecResponse struct {
|
||||
Spec *fleet.EnrollSecretSpec `json:"spec"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getEnrollSecretSpecResponse) error() error { return r.Err }
|
||||
|
||||
func getEnrollSecretSpecEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
specs, err := svc.GetEnrollSecretSpec(ctx)
|
||||
if err != nil {
|
||||
return getEnrollSecretSpecResponse{Err: err}, nil
|
||||
}
|
||||
return getEnrollSecretSpecResponse{Spec: specs}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) GetEnrollSecretSpec(ctx context.Context) (*fleet.EnrollSecretSpec, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.EnrollSecret{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secrets, err := svc.ds.GetEnrollSecrets(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fleet.EnrollSecretSpec{Secrets: secrets}, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Version
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type versionResponse struct {
|
||||
*version.Info
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r versionResponse) error() error { return r.Err }
|
||||
|
||||
func versionEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
info, err := svc.Version(ctx)
|
||||
if err != nil {
|
||||
return versionResponse{Err: err}, nil
|
||||
}
|
||||
return versionResponse{Info: info}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) Version(ctx context.Context) (*version.Info, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := version.Version()
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get Certificate Chain
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getCertificateResponse struct {
|
||||
CertificateChain []byte `json:"certificate_chain"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getCertificateResponse) error() error { return r.Err }
|
||||
|
||||
func getCertificateEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
chain, err := svc.CertificateChain(ctx)
|
||||
if err != nil {
|
||||
return getCertificateResponse{Err: err}, nil
|
||||
}
|
||||
return getCertificateResponse{CertificateChain: chain}, nil
|
||||
}
|
||||
|
||||
// Certificate returns the PEM encoded certificate chain for osqueryd TLS termination.
|
||||
func (svc *Service) CertificateChain(ctx context.Context) ([]byte, error) {
|
||||
config, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(config.ServerSettings.ServerURL)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "parsing serverURL")
|
||||
}
|
||||
|
||||
conn, err := connectTLS(ctx, u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chain(ctx, conn.ConnectionState(), u.Hostname())
|
||||
}
|
||||
|
||||
func connectTLS(ctx context.Context, serverURL *url.URL) (*tls.Conn, error) {
|
||||
var hostport string
|
||||
if serverURL.Port() == "" {
|
||||
hostport = net.JoinHostPort(serverURL.Host, "443")
|
||||
} else {
|
||||
hostport = serverURL.Host
|
||||
}
|
||||
|
||||
// attempt dialing twice, first with a secure conn, and then
|
||||
// if that fails, use insecure
|
||||
dial := func(insecure bool) (*tls.Conn, error) {
|
||||
conn, err := tls.Dial("tcp", hostport, &tls.Config{
|
||||
InsecureSkipVerify: insecure})
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "dial tls")
|
||||
}
|
||||
defer conn.Close()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
var (
|
||||
conn *tls.Conn
|
||||
err error
|
||||
)
|
||||
|
||||
conn, err = dial(false)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
conn, err = dial(true)
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// chain builds a PEM encoded certificate chain using the PeerCertificates
|
||||
// in tls.ConnectionState. chain uses the hostname to omit the Leaf certificate
|
||||
// from the chain.
|
||||
func chain(ctx context.Context, cs tls.ConnectionState, hostname string) ([]byte, error) {
|
||||
buf := bytes.NewBuffer([]byte(""))
|
||||
|
||||
verifyEncode := func(chain []*x509.Certificate) error {
|
||||
for _, cert := range chain {
|
||||
if len(chain) > 1 {
|
||||
// drop the leaf certificate from the chain. osqueryd does not
|
||||
// need it to establish a secure connection
|
||||
if err := cert.VerifyHostname(hostname); err == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := encodePEMCertificate(buf, cert); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// use verified chains if available(which adds the root CA), otherwise
|
||||
// use the certificate chain offered by the server (if terminated with
|
||||
// self-signed certs)
|
||||
if len(cs.VerifiedChains) != 0 {
|
||||
for _, chain := range cs.VerifiedChains {
|
||||
if err := verifyEncode(chain); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "encode verified chains pem")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := verifyEncode(cs.PeerCertificates); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "encode peer certificates pem")
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func encodePEMCertificate(buf io.Writer, cert *x509.Certificate) error {
|
||||
block := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}
|
||||
return pem.Encode(buf, block)
|
||||
}
|
262
server/service/appconfig_test.go
Normal file
262
server/service/appconfig_test.go
Normal file
@ -0,0 +1,262 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAppConfigAuth(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
// start a TLS server and use its URL as the server URL in the app config,
|
||||
// required by the CertificateChain service call.
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
defer srv.Close()
|
||||
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
return &fleet.AppConfig{
|
||||
ServerSettings: fleet.ServerSettings{
|
||||
ServerURL: srv.URL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
ds.SaveAppConfigFunc = func(ctx context.Context, conf *fleet.AppConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *fleet.User
|
||||
shouldFailWrite bool
|
||||
shouldFailRead bool
|
||||
}{
|
||||
{
|
||||
"global admin",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global maintainer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global observer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"team admin",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"team maintainer",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"team observer",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"user",
|
||||
&fleet.User{ID: 777},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
|
||||
|
||||
_, err := svc.AppConfig(ctx)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
|
||||
_, err = svc.ModifyAppConfig(ctx, []byte(`{}`))
|
||||
checkAuthErr(t, tt.shouldFailWrite, err)
|
||||
|
||||
_, err = svc.Version(ctx)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
|
||||
_, err = svc.CertificateChain(ctx)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrollSecretAuth(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
ds.ApplyEnrollSecretsFunc = func(ctx context.Context, tid *uint, secrets []*fleet.EnrollSecret) error {
|
||||
return nil
|
||||
}
|
||||
ds.GetEnrollSecretsFunc = func(ctx context.Context, tid *uint) ([]*fleet.EnrollSecret, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *fleet.User
|
||||
shouldFailWrite bool
|
||||
shouldFailRead bool
|
||||
}{
|
||||
{
|
||||
"global admin",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global maintainer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleMaintainer)},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global observer",
|
||||
&fleet.User{GlobalRole: ptr.String(fleet.RoleObserver)},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team admin",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleAdmin}}},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team maintainer",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleMaintainer}}},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"team observer",
|
||||
&fleet.User{Teams: []fleet.UserTeam{{Team: fleet.Team{ID: 1}, Role: fleet.RoleObserver}}},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"user",
|
||||
&fleet.User{ID: 777},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
|
||||
|
||||
err := svc.ApplyEnrollSecretSpec(ctx, &fleet.EnrollSecretSpec{Secrets: []*fleet.EnrollSecret{{Secret: "ABC"}}})
|
||||
checkAuthErr(t, tt.shouldFailWrite, err)
|
||||
|
||||
_, err = svc.GetEnrollSecretSpec(ctx)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateChain(t *testing.T) {
|
||||
server, teardown := setupCertificateChain(t)
|
||||
defer teardown()
|
||||
|
||||
certFile := "testdata/server.pem"
|
||||
cert, err := tls.LoadX509KeyPair(certFile, "testdata/server.key")
|
||||
require.Nil(t, err)
|
||||
server.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
server.StartTLS()
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
require.Nil(t, err)
|
||||
|
||||
conn, err := connectTLS(context.Background(), u)
|
||||
require.Nil(t, err)
|
||||
|
||||
have, want := len(conn.ConnectionState().PeerCertificates), len(cert.Certificate)
|
||||
require.Equal(t, have, want)
|
||||
|
||||
original, _ := ioutil.ReadFile(certFile)
|
||||
returned, err := chain(context.Background(), conn.ConnectionState(), "")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, returned, original)
|
||||
}
|
||||
|
||||
func echoHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dump, err := httputil.DumpRequest(r, true)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(dump)
|
||||
})
|
||||
}
|
||||
|
||||
func setupCertificateChain(t *testing.T) (server *httptest.Server, teardown func()) {
|
||||
server = httptest.NewUnstartedServer(echoHandler())
|
||||
return server, server.Close
|
||||
}
|
||||
|
||||
func TestSSONotPresent(t *testing.T) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
var p fleet.AppConfig
|
||||
validateSSOSettings(p, &fleet.AppConfig{}, invalid)
|
||||
assert.False(t, invalid.HasErrors())
|
||||
|
||||
}
|
||||
|
||||
func TestNeedFieldsPresent(t *testing.T) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
config := fleet.AppConfig{
|
||||
SSOSettings: fleet.SSOSettings{
|
||||
EnableSSO: true,
|
||||
EntityID: "fleet",
|
||||
IssuerURI: "http://issuer.idp.com",
|
||||
MetadataURL: "http://isser.metadata.com",
|
||||
IDPName: "onelogin",
|
||||
},
|
||||
}
|
||||
validateSSOSettings(config, &fleet.AppConfig{}, invalid)
|
||||
assert.False(t, invalid.HasErrors())
|
||||
}
|
||||
|
||||
func TestMissingMetadata(t *testing.T) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
config := fleet.AppConfig{
|
||||
SSOSettings: fleet.SSOSettings{
|
||||
EnableSSO: true,
|
||||
EntityID: "fleet",
|
||||
IssuerURI: "http://issuer.idp.com",
|
||||
IDPName: "onelogin",
|
||||
},
|
||||
}
|
||||
validateSSOSettings(config, &fleet.AppConfig{}, invalid)
|
||||
require.True(t, invalid.HasErrors())
|
||||
assert.Contains(t, invalid.Error(), "metadata")
|
||||
assert.Contains(t, invalid.Error(), "either metadata or metadata_url must be defined")
|
||||
}
|
@ -1,191 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
"github.com/kolide/kit/version"
|
||||
)
|
||||
|
||||
type appConfigRequest struct {
|
||||
Payload json.RawMessage
|
||||
}
|
||||
|
||||
type appConfigResponse struct {
|
||||
fleet.AppConfig
|
||||
|
||||
UpdateInterval *fleet.UpdateIntervalConfig `json:"update_interval"`
|
||||
Vulnerabilities *fleet.VulnerabilitiesConfig `json:"vulnerabilities"`
|
||||
|
||||
// License is loaded from the service
|
||||
License *fleet.LicenseInfo `json:"license,omitempty"`
|
||||
// Logging is loaded on the fly rather than from the database.
|
||||
Logging *fleet.Logging `json:"logging,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r appConfigResponse) error() error { return r.Err }
|
||||
|
||||
func makeGetAppConfigEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not fetch user")
|
||||
}
|
||||
config, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
license, err := svc.License(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loggingConfig, err := svc.LoggingConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updateIntervalConfig, err := svc.UpdateIntervalConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vulnConfig, err := svc.VulnerabilitiesConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var smtpSettings fleet.SMTPSettings
|
||||
var ssoSettings fleet.SSOSettings
|
||||
var hostExpirySettings fleet.HostExpirySettings
|
||||
var agentOptions *json.RawMessage
|
||||
// only admin can see smtp, sso, and host expiry settings
|
||||
if vc.User.GlobalRole != nil && *vc.User.GlobalRole == fleet.RoleAdmin {
|
||||
smtpSettings = config.SMTPSettings
|
||||
if smtpSettings.SMTPPassword != "" {
|
||||
smtpSettings.SMTPPassword = "********"
|
||||
}
|
||||
ssoSettings = config.SSOSettings
|
||||
hostExpirySettings = config.HostExpirySettings
|
||||
agentOptions = config.AgentOptions
|
||||
}
|
||||
hostSettings := config.HostSettings
|
||||
response := appConfigResponse{
|
||||
AppConfig: fleet.AppConfig{
|
||||
OrgInfo: config.OrgInfo,
|
||||
ServerSettings: config.ServerSettings,
|
||||
HostSettings: hostSettings,
|
||||
VulnerabilitySettings: config.VulnerabilitySettings,
|
||||
|
||||
SMTPSettings: smtpSettings,
|
||||
SSOSettings: ssoSettings,
|
||||
HostExpirySettings: hostExpirySettings,
|
||||
AgentOptions: agentOptions,
|
||||
|
||||
WebhookSettings: config.WebhookSettings,
|
||||
},
|
||||
UpdateInterval: updateIntervalConfig,
|
||||
Vulnerabilities: vulnConfig,
|
||||
License: license,
|
||||
Logging: loggingConfig,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
func makeModifyAppConfigEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(appConfigRequest)
|
||||
config, err := svc.ModifyAppConfig(ctx, req.Payload)
|
||||
if err != nil {
|
||||
return appConfigResponse{Err: err}, nil
|
||||
}
|
||||
license, err := svc.License(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loggingConfig, err := svc.LoggingConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response := appConfigResponse{
|
||||
AppConfig: *config,
|
||||
License: license,
|
||||
Logging: loggingConfig,
|
||||
}
|
||||
|
||||
if response.SMTPSettings.SMTPPassword != "" {
|
||||
response.SMTPSettings.SMTPPassword = "********"
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Apply enroll secret spec
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type applyEnrollSecretSpecRequest struct {
|
||||
Spec *fleet.EnrollSecretSpec `json:"spec"`
|
||||
}
|
||||
|
||||
type applyEnrollSecretSpecResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r applyEnrollSecretSpecResponse) error() error { return r.Err }
|
||||
|
||||
func makeApplyEnrollSecretSpecEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(applyEnrollSecretSpecRequest)
|
||||
err := svc.ApplyEnrollSecretSpec(ctx, req.Spec)
|
||||
if err != nil {
|
||||
return applyEnrollSecretSpecResponse{Err: err}, nil
|
||||
}
|
||||
return applyEnrollSecretSpecResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get enroll secret spec
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getEnrollSecretSpecResponse struct {
|
||||
Spec *fleet.EnrollSecretSpec `json:"spec"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getEnrollSecretSpecResponse) error() error { return r.Err }
|
||||
|
||||
func makeGetEnrollSecretSpecEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
specs, err := svc.GetEnrollSecretSpec(ctx)
|
||||
if err != nil {
|
||||
return getEnrollSecretSpecResponse{Err: err}, nil
|
||||
}
|
||||
return getEnrollSecretSpecResponse{Spec: specs}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Version
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type versionResponse struct {
|
||||
*version.Info
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r versionResponse) error() error { return r.Err }
|
||||
|
||||
func makeVersionEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
info, err := svc.Version(ctx)
|
||||
if err != nil {
|
||||
return versionResponse{Err: err}, nil
|
||||
}
|
||||
return versionResponse{Info: info}, nil
|
||||
}
|
||||
}
|
@ -1,25 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
type certificateResponse struct {
|
||||
CertificateChain []byte `json:"certificate_chain"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r certificateResponse) error() error { return r.Err }
|
||||
|
||||
func makeCertificateEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
chain, err := svc.CertificateChain(ctx)
|
||||
if err != nil {
|
||||
return certificateResponse{Err: err}, nil
|
||||
}
|
||||
return certificateResponse{CertificateChain: chain}, nil
|
||||
}
|
||||
}
|
@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"html/template"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
@ -75,64 +74,6 @@ func makeLogoutEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get Info About Session
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getInfoAboutSessionRequest struct {
|
||||
ID uint
|
||||
}
|
||||
|
||||
type getInfoAboutSessionResponse struct {
|
||||
SessionID uint `json:"session_id"`
|
||||
UserID uint `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getInfoAboutSessionResponse) error() error { return r.Err }
|
||||
|
||||
func makeGetInfoAboutSessionEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(getInfoAboutSessionRequest)
|
||||
session, err := svc.GetInfoAboutSession(ctx, req.ID)
|
||||
if err != nil {
|
||||
return getInfoAboutSessionResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
return getInfoAboutSessionResponse{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete Session
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type deleteSessionRequest struct {
|
||||
ID uint
|
||||
}
|
||||
|
||||
type deleteSessionResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r deleteSessionResponse) error() error { return r.Err }
|
||||
|
||||
func makeDeleteSessionEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(deleteSessionRequest)
|
||||
err := svc.DeleteSession(ctx, req.ID)
|
||||
if err != nil {
|
||||
return deleteSessionResponse{Err: err}, nil
|
||||
}
|
||||
return deleteSessionResponse{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type initiateSSORequest struct {
|
||||
RelayURL string `json:"relay_url"`
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
@ -24,24 +23,6 @@ func makeCreateUserFromInviteEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
func makeGetSessionUserEndpoint(svc fleet.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
user, err := svc.AuthenticatedUser(ctx)
|
||||
if err != nil {
|
||||
return getUserResponse{Err: err}, nil
|
||||
}
|
||||
availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, fleet.ErrMissingLicense) {
|
||||
availableTeams = []*fleet.TeamSummary{}
|
||||
} else {
|
||||
return getUserResponse{Err: err}, nil
|
||||
}
|
||||
}
|
||||
return getUserResponse{User: user, AvailableTeams: availableTeams}, nil
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Reset Password
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -27,15 +27,8 @@ type FleetEndpoints struct {
|
||||
Logout endpoint.Endpoint
|
||||
ForgotPassword endpoint.Endpoint
|
||||
ResetPassword endpoint.Endpoint
|
||||
Me endpoint.Endpoint
|
||||
CreateUserWithInvite endpoint.Endpoint
|
||||
PerformRequiredPasswordReset endpoint.Endpoint
|
||||
GetSessionInfo endpoint.Endpoint
|
||||
DeleteSession endpoint.Endpoint
|
||||
GetAppConfig endpoint.Endpoint
|
||||
ModifyAppConfig endpoint.Endpoint
|
||||
ApplyEnrollSecretSpec endpoint.Endpoint
|
||||
GetEnrollSecretSpec endpoint.Endpoint
|
||||
CreateInvite endpoint.Endpoint
|
||||
ListInvites endpoint.Endpoint
|
||||
DeleteInvite endpoint.Endpoint
|
||||
@ -60,14 +53,12 @@ type FleetEndpoints struct {
|
||||
CarveBegin endpoint.Endpoint
|
||||
CarveBlock endpoint.Endpoint
|
||||
SearchTargets endpoint.Endpoint
|
||||
GetCertificate endpoint.Endpoint
|
||||
ChangeEmail endpoint.Endpoint
|
||||
InitiateSSO endpoint.Endpoint
|
||||
CallbackSSO endpoint.Endpoint
|
||||
SSOSettings endpoint.Endpoint
|
||||
StatusResultStore endpoint.Endpoint
|
||||
StatusLiveQuery endpoint.Endpoint
|
||||
Version endpoint.Endpoint
|
||||
}
|
||||
|
||||
// MakeFleetServerEndpoints creates the Fleet API endpoints.
|
||||
@ -96,16 +87,10 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
|
||||
PerformRequiredPasswordReset: logged(canPerformPasswordReset(makePerformRequiredPasswordResetEndpoint(svc))),
|
||||
|
||||
// Standard user authentication routes
|
||||
Me: authenticatedUser(svc, makeGetSessionUserEndpoint(svc)),
|
||||
GetSessionInfo: authenticatedUser(svc, makeGetInfoAboutSessionEndpoint(svc)),
|
||||
DeleteSession: authenticatedUser(svc, makeDeleteSessionEndpoint(svc)),
|
||||
GetAppConfig: authenticatedUser(svc, makeGetAppConfigEndpoint(svc)),
|
||||
ModifyAppConfig: authenticatedUser(svc, makeModifyAppConfigEndpoint(svc)),
|
||||
ApplyEnrollSecretSpec: authenticatedUser(svc, makeApplyEnrollSecretSpecEndpoint(svc)),
|
||||
GetEnrollSecretSpec: authenticatedUser(svc, makeGetEnrollSecretSpecEndpoint(svc)),
|
||||
CreateInvite: authenticatedUser(svc, makeCreateInviteEndpoint(svc)),
|
||||
ListInvites: authenticatedUser(svc, makeListInvitesEndpoint(svc)),
|
||||
DeleteInvite: authenticatedUser(svc, makeDeleteInviteEndpoint(svc)),
|
||||
CreateInvite: authenticatedUser(svc, makeCreateInviteEndpoint(svc)),
|
||||
ListInvites: authenticatedUser(svc, makeListInvitesEndpoint(svc)),
|
||||
DeleteInvite: authenticatedUser(svc, makeDeleteInviteEndpoint(svc)),
|
||||
|
||||
GetQuery: authenticatedUser(svc, makeGetQueryEndpoint(svc)),
|
||||
ListQueries: authenticatedUser(svc, makeListQueriesEndpoint(svc)),
|
||||
CreateQuery: authenticatedUser(svc, makeCreateQueryEndpoint(svc)),
|
||||
@ -119,9 +104,7 @@ func MakeFleetServerEndpoints(svc fleet.Service, urlPrefix string, limitStore th
|
||||
CreateDistributedQueryCampaign: authenticatedUser(svc, makeCreateDistributedQueryCampaignEndpoint(svc)),
|
||||
CreateDistributedQueryCampaignByNames: authenticatedUser(svc, makeCreateDistributedQueryCampaignByNamesEndpoint(svc)),
|
||||
SearchTargets: authenticatedUser(svc, makeSearchTargetsEndpoint(svc)),
|
||||
GetCertificate: authenticatedUser(svc, makeCertificateEndpoint(svc)),
|
||||
ChangeEmail: authenticatedUser(svc, makeChangeEmailEndpoint(svc)),
|
||||
Version: authenticatedUser(svc, makeVersionEndpoint(svc)),
|
||||
|
||||
// Authenticated status endpoints
|
||||
StatusResultStore: authenticatedUser(svc, makeStatusResultStoreEndpoint(svc)),
|
||||
@ -147,15 +130,8 @@ type fleetHandlers struct {
|
||||
Logout http.Handler
|
||||
ForgotPassword http.Handler
|
||||
ResetPassword http.Handler
|
||||
Me http.Handler
|
||||
CreateUserWithInvite http.Handler
|
||||
PerformRequiredPasswordReset http.Handler
|
||||
GetSessionInfo http.Handler
|
||||
DeleteSession http.Handler
|
||||
GetAppConfig http.Handler
|
||||
ModifyAppConfig http.Handler
|
||||
ApplyEnrollSecretSpec http.Handler
|
||||
GetEnrollSecretSpec http.Handler
|
||||
CreateInvite http.Handler
|
||||
ListInvites http.Handler
|
||||
DeleteInvite http.Handler
|
||||
@ -180,14 +156,12 @@ type fleetHandlers struct {
|
||||
CarveBegin http.Handler
|
||||
CarveBlock http.Handler
|
||||
SearchTargets http.Handler
|
||||
GetCertificate http.Handler
|
||||
ChangeEmail http.Handler
|
||||
InitiateSSO http.Handler
|
||||
CallbackSSO http.Handler
|
||||
SettingsSSO http.Handler
|
||||
StatusResultStore http.Handler
|
||||
StatusLiveQuery http.Handler
|
||||
Version http.Handler
|
||||
}
|
||||
|
||||
func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandlers {
|
||||
@ -200,15 +174,8 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
|
||||
Logout: newServer(e.Logout, decodeNoParamsRequest),
|
||||
ForgotPassword: newServer(e.ForgotPassword, decodeForgotPasswordRequest),
|
||||
ResetPassword: newServer(e.ResetPassword, decodeResetPasswordRequest),
|
||||
Me: newServer(e.Me, decodeNoParamsRequest),
|
||||
CreateUserWithInvite: newServer(e.CreateUserWithInvite, decodeCreateUserRequest),
|
||||
PerformRequiredPasswordReset: newServer(e.PerformRequiredPasswordReset, decodePerformRequiredPasswordResetRequest),
|
||||
GetSessionInfo: newServer(e.GetSessionInfo, decodeGetInfoAboutSessionRequest),
|
||||
DeleteSession: newServer(e.DeleteSession, decodeDeleteSessionRequest),
|
||||
GetAppConfig: newServer(e.GetAppConfig, decodeNoParamsRequest),
|
||||
ModifyAppConfig: newServer(e.ModifyAppConfig, decodeModifyAppConfigRequest),
|
||||
ApplyEnrollSecretSpec: newServer(e.ApplyEnrollSecretSpec, decodeApplyEnrollSecretSpecRequest),
|
||||
GetEnrollSecretSpec: newServer(e.GetEnrollSecretSpec, decodeNoParamsRequest),
|
||||
CreateInvite: newServer(e.CreateInvite, decodeCreateInviteRequest),
|
||||
ListInvites: newServer(e.ListInvites, decodeListInvitesRequest),
|
||||
DeleteInvite: newServer(e.DeleteInvite, decodeDeleteInviteRequest),
|
||||
@ -233,14 +200,12 @@ func makeKitHandlers(e FleetEndpoints, opts []kithttp.ServerOption) *fleetHandle
|
||||
CarveBegin: newServer(e.CarveBegin, decodeCarveBeginRequest),
|
||||
CarveBlock: newServer(e.CarveBlock, decodeCarveBlockRequest),
|
||||
SearchTargets: newServer(e.SearchTargets, decodeSearchTargetsRequest),
|
||||
GetCertificate: newServer(e.GetCertificate, decodeNoParamsRequest),
|
||||
ChangeEmail: newServer(e.ChangeEmail, decodeChangeEmailRequest),
|
||||
InitiateSSO: newServer(e.InitiateSSO, decodeInitiateSSORequest),
|
||||
CallbackSSO: newServer(e.CallbackSSO, decodeCallbackSSORequest),
|
||||
SettingsSSO: newServer(e.SSOSettings, decodeNoParamsRequest),
|
||||
StatusResultStore: newServer(e.StatusResultStore, decodeNoParamsRequest),
|
||||
StatusLiveQuery: newServer(e.StatusLiveQuery, decodeNoParamsRequest),
|
||||
Version: newServer(e.Version, decodeNoParamsRequest),
|
||||
}
|
||||
}
|
||||
|
||||
@ -415,7 +380,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
||||
r.Handle("/api/v1/fleet/logout", h.Logout).Methods("POST").Name("logout")
|
||||
r.Handle("/api/v1/fleet/forgot_password", h.ForgotPassword).Methods("POST").Name("forgot_password")
|
||||
r.Handle("/api/v1/fleet/reset_password", h.ResetPassword).Methods("POST").Name("reset_password")
|
||||
r.Handle("/api/v1/fleet/me", h.Me).Methods("GET").Name("me")
|
||||
r.Handle("/api/v1/fleet/perform_required_password_reset", h.PerformRequiredPasswordReset).Methods("POST").Name("perform_required_password_reset")
|
||||
r.Handle("/api/v1/fleet/sso", h.InitiateSSO).Methods("POST").Name("intiate_sso")
|
||||
r.Handle("/api/v1/fleet/sso", h.SettingsSSO).Methods("GET").Name("sso_config")
|
||||
@ -423,14 +387,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
||||
|
||||
r.Handle("/api/v1/fleet/users", h.CreateUserWithInvite).Methods("POST").Name("create_user_with_invite")
|
||||
|
||||
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.GetSessionInfo).Methods("GET").Name("get_session_info")
|
||||
r.Handle("/api/v1/fleet/sessions/{id:[0-9]+}", h.DeleteSession).Methods("DELETE").Name("delete_session")
|
||||
|
||||
r.Handle("/api/v1/fleet/config/certificate", h.GetCertificate).Methods("GET").Name("get_certificate")
|
||||
r.Handle("/api/v1/fleet/config", h.GetAppConfig).Methods("GET").Name("get_app_config")
|
||||
r.Handle("/api/v1/fleet/config", h.ModifyAppConfig).Methods("PATCH").Name("modify_app_config")
|
||||
r.Handle("/api/v1/fleet/spec/enroll_secret", h.ApplyEnrollSecretSpec).Methods("POST").Name("apply_enroll_secret_spec")
|
||||
r.Handle("/api/v1/fleet/spec/enroll_secret", h.GetEnrollSecretSpec).Methods("GET").Name("get_enroll_secret_spec")
|
||||
r.Handle("/api/v1/fleet/invites", h.CreateInvite).Methods("POST").Name("create_invite")
|
||||
r.Handle("/api/v1/fleet/invites", h.ListInvites).Methods("GET").Name("list_invites")
|
||||
r.Handle("/api/v1/fleet/invites/{id:[0-9]+}", h.DeleteInvite).Methods("DELETE").Name("delete_invite")
|
||||
@ -453,8 +409,6 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
||||
|
||||
r.Handle("/api/v1/fleet/targets", h.SearchTargets).Methods("POST").Name("search_targets")
|
||||
|
||||
r.Handle("/api/v1/fleet/version", h.Version).Methods("GET").Name("version")
|
||||
|
||||
r.Handle("/api/v1/fleet/status/result_store", h.StatusResultStore).Methods("GET").Name("status_result_store")
|
||||
r.Handle("/api/v1/fleet/status/live_query", h.StatusLiveQuery).Methods("GET").Name("status_live_query")
|
||||
|
||||
@ -470,6 +424,17 @@ func attachFleetAPIRoutes(r *mux.Router, h *fleetHandlers) {
|
||||
func attachNewStyleFleetAPIRoutes(r *mux.Router, svc fleet.Service, opts []kithttp.ServerOption) {
|
||||
e := NewUserAuthenticatedEndpointer(svc, opts, r, "v1")
|
||||
|
||||
e.GET("/api/_version_/fleet/me", meEndpoint, nil)
|
||||
e.GET("/api/_version_/fleet/sessions/{id:[0-9]+}", getInfoAboutSessionEndpoint, getInfoAboutSessionRequest{})
|
||||
e.DELETE("/api/_version_/fleet/sessions/{id:[0-9]+}", deleteSessionEndpoint, deleteSessionRequest{})
|
||||
|
||||
e.GET("/api/_version_/fleet/config/certificate", getCertificateEndpoint, nil)
|
||||
e.GET("/api/_version_/fleet/config", getAppConfigEndpoint, nil)
|
||||
e.PATCH("/api/_version_/fleet/config", modifyAppConfigEndpoint, modifyAppConfigRequest{})
|
||||
e.POST("/api/_version_/fleet/spec/enroll_secret", applyEnrollSecretSpecEndpoint, applyEnrollSecretSpecRequest{})
|
||||
e.GET("/api/_version_/fleet/spec/enroll_secret", getEnrollSecretSpecEndpoint, nil)
|
||||
e.GET("/api/_version_/fleet/version", versionEndpoint, nil)
|
||||
|
||||
e.POST("/api/_version_/fleet/users/roles/spec", applyUserRoleSpecsEndpoint, applyUserRoleSpecsRequest{})
|
||||
e.POST("/api/_version_/fleet/translate", translatorEndpoint, translatorRequest{})
|
||||
e.POST("/api/_version_/fleet/spec/teams", applyTeamSpecsEndpoint, applyTeamSpecsRequest{})
|
||||
|
@ -54,18 +54,6 @@ func TestAPIRoutes(t *testing.T) {
|
||||
verb: "POST",
|
||||
uri: "/api/v1/fleet/reset_password",
|
||||
},
|
||||
{
|
||||
verb: "GET",
|
||||
uri: "/api/v1/fleet/me",
|
||||
},
|
||||
{
|
||||
verb: "GET",
|
||||
uri: "/api/v1/fleet/config",
|
||||
},
|
||||
{
|
||||
verb: "PATCH",
|
||||
uri: "/api/v1/fleet/config",
|
||||
},
|
||||
{
|
||||
verb: "GET",
|
||||
uri: "/api/v1/fleet/invites",
|
||||
|
@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -1936,19 +1937,11 @@ func (s *integrationTestSuite) TestUsers() {
|
||||
|
||||
// test available teams returned by `/me` endpoint for existing user
|
||||
var getMeResp getUserResponse
|
||||
key := make([]byte, 64)
|
||||
sessionKey := base64.StdEncoding.EncodeToString(key)
|
||||
session := &fleet.Session{
|
||||
UserID: uint(1),
|
||||
Key: sessionKey,
|
||||
AccessedAt: time.Now().UTC(),
|
||||
}
|
||||
_, err := s.ds.NewSession(context.Background(), session)
|
||||
require.NoError(t, err)
|
||||
ssn := createSession(t, 1, s.ds)
|
||||
resp := s.DoRawWithHeaders("GET", "/api/v1/fleet/me", []byte(""), http.StatusOK, map[string]string{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", sessionKey),
|
||||
"Authorization": fmt.Sprintf("Bearer %s", ssn.Key),
|
||||
})
|
||||
err = json.NewDecoder(resp.Body).Decode(&getMeResp)
|
||||
err := json.NewDecoder(resp.Body).Decode(&getMeResp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint(1), getMeResp.User.ID)
|
||||
assert.NotNil(t, getMeResp.User.GlobalRole)
|
||||
@ -2267,3 +2260,102 @@ func (s *integrationTestSuite) TestTeamPoliciesTeamNotExists() {
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/teams/%d/policies", 9999999), nil, http.StatusNotFound, &teamPoliciesResponse)
|
||||
require.Len(t, teamPoliciesResponse.Policies, 0)
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestSessionInfo() {
|
||||
t := s.T()
|
||||
|
||||
ssn := createSession(t, 1, s.ds)
|
||||
|
||||
var meResp getUserResponse
|
||||
resp := s.DoRawWithHeaders("GET", "/api/v1/fleet/me", nil, http.StatusOK, map[string]string{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", ssn.Key),
|
||||
})
|
||||
require.NoError(t, json.NewDecoder(resp.Body).Decode(&meResp))
|
||||
assert.Equal(t, uint(1), meResp.User.ID)
|
||||
|
||||
// get info about session
|
||||
var getResp getInfoAboutSessionResponse
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/sessions/%d", ssn.ID), nil, http.StatusOK, &getResp)
|
||||
assert.Equal(t, ssn.ID, getResp.SessionID)
|
||||
assert.Equal(t, uint(1), getResp.UserID)
|
||||
|
||||
// get info about session - non-existing: appears to deliberately return 500 due to forbidden,
|
||||
// which takes precedence vs the not found returned by the datastore (it still shouldn't be a
|
||||
// 500 though).
|
||||
s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/sessions/%d", ssn.ID+1), nil, http.StatusInternalServerError, &getResp)
|
||||
|
||||
// delete session
|
||||
var delResp deleteSessionResponse
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/sessions/%d", ssn.ID), nil, http.StatusOK, &delResp)
|
||||
|
||||
// delete session - non-existing: again, 500 due to forbidden instead of 404.
|
||||
s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/sessions/%d", ssn.ID), nil, http.StatusInternalServerError, &delResp)
|
||||
}
|
||||
|
||||
func (s *integrationTestSuite) TestAppConfig() {
|
||||
t := s.T()
|
||||
|
||||
// get the app config
|
||||
var acResp appConfigResponse
|
||||
s.DoJSON("GET", "/api/v1/fleet/config", nil, http.StatusOK, &acResp)
|
||||
assert.Equal(t, "free", acResp.License.Tier)
|
||||
assert.Equal(t, "", acResp.OrgInfo.OrgName)
|
||||
|
||||
// no server settings set for the URL, so not possible to test the
|
||||
// certificate endpoint
|
||||
acResp = appConfigResponse{}
|
||||
s.DoJSON("PATCH", "/api/v1/fleet/config", json.RawMessage(`{
|
||||
"org_info": {
|
||||
"org_name": "test"
|
||||
}
|
||||
}`), http.StatusOK, &acResp)
|
||||
assert.Equal(t, "test", acResp.OrgInfo.OrgName)
|
||||
|
||||
var verResp versionResponse
|
||||
s.DoJSON("GET", "/api/v1/fleet/version", nil, http.StatusOK, &verResp)
|
||||
assert.NotEmpty(t, verResp.Branch)
|
||||
|
||||
// get enroll secrets, none yet
|
||||
var specResp getEnrollSecretSpecResponse
|
||||
s.DoJSON("GET", "/api/v1/fleet/spec/enroll_secret", nil, http.StatusOK, &specResp)
|
||||
assert.Empty(t, specResp.Spec.Secrets)
|
||||
|
||||
// apply spec, one secret
|
||||
var applyResp applyEnrollSecretSpecResponse
|
||||
s.DoJSON("POST", "/api/v1/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{
|
||||
Spec: &fleet.EnrollSecretSpec{
|
||||
Secrets: []*fleet.EnrollSecret{{Secret: "XYZ"}},
|
||||
},
|
||||
}, http.StatusOK, &applyResp)
|
||||
|
||||
// get enroll secrets, one
|
||||
s.DoJSON("GET", "/api/v1/fleet/spec/enroll_secret", nil, http.StatusOK, &specResp)
|
||||
require.Len(t, specResp.Spec.Secrets, 1)
|
||||
assert.Equal(t, "XYZ", specResp.Spec.Secrets[0].Secret)
|
||||
|
||||
// remove secret just to prevent affecting other tests
|
||||
s.DoJSON("POST", "/api/v1/fleet/spec/enroll_secret", applyEnrollSecretSpecRequest{
|
||||
Spec: &fleet.EnrollSecretSpec{},
|
||||
}, http.StatusOK, &applyResp)
|
||||
|
||||
s.DoJSON("GET", "/api/v1/fleet/spec/enroll_secret", nil, http.StatusOK, &specResp)
|
||||
require.Len(t, specResp.Spec.Secrets, 0)
|
||||
}
|
||||
|
||||
// creates a session and returns it, its key is to be passed as authorization header.
|
||||
func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session {
|
||||
key := make([]byte, 64)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionKey := base64.StdEncoding.EncodeToString(key)
|
||||
session := &fleet.Session{
|
||||
UserID: uid,
|
||||
Key: sessionKey,
|
||||
AccessedAt: time.Now().UTC(),
|
||||
}
|
||||
ssn, err := ds.NewSession(context.Background(), session)
|
||||
require.NoError(t, err)
|
||||
|
||||
return ssn
|
||||
}
|
||||
|
@ -1,9 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"strings"
|
||||
@ -13,7 +11,6 @@ import (
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mail"
|
||||
"github.com/kolide/kit/version"
|
||||
)
|
||||
|
||||
// mailError is set when an error performing mail operations
|
||||
@ -61,14 +58,6 @@ func (svc *Service) NewAppConfig(ctx context.Context, p fleet.AppConfig) (*fleet
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
func (svc *Service) AppConfig(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return svc.ds.AppConfig(ctx)
|
||||
}
|
||||
|
||||
func (svc *Service) sendTestEmail(ctx context.Context, config *fleet.AppConfig) error {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
@ -91,77 +80,10 @@ func (svc *Service) sendTestEmail(ctx context.Context, config *fleet.AppConfig)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte) (*fleet.AppConfig, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionWrite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
appConfig, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We apply the config that is incoming to the old one
|
||||
decoder := json.NewDecoder(bytes.NewReader(p))
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(&appConfig); err != nil {
|
||||
return nil, &badRequestError{message: err.Error()}
|
||||
}
|
||||
|
||||
if appConfig.SMTPSettings.SMTPEnabled || appConfig.SMTPSettings.SMTPConfigured {
|
||||
if err = svc.sendTestEmail(ctx, appConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appConfig.SMTPSettings.SMTPConfigured = true
|
||||
} else if appConfig.SMTPSettings.SMTPEnabled {
|
||||
appConfig.SMTPSettings.SMTPConfigured = false
|
||||
}
|
||||
|
||||
if err := svc.ds.SaveAppConfig(ctx, appConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appConfig, nil
|
||||
}
|
||||
|
||||
func cleanupURL(url string) string {
|
||||
return strings.TrimRight(strings.Trim(url, " \t\n"), "/")
|
||||
}
|
||||
|
||||
func (svc *Service) ApplyEnrollSecretSpec(ctx context.Context, spec *fleet.EnrollSecretSpec) error {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.EnrollSecret{}, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, s := range spec.Secrets {
|
||||
if s.Secret == "" {
|
||||
return ctxerr.New(ctx, "enroll secret must not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
return svc.ds.ApplyEnrollSecrets(ctx, nil, spec.Secrets)
|
||||
}
|
||||
|
||||
func (svc *Service) GetEnrollSecretSpec(ctx context.Context) (*fleet.EnrollSecretSpec, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.EnrollSecret{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secrets, err := svc.ds.GetEnrollSecrets(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fleet.EnrollSecretSpec{Secrets: secrets}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) Version(ctx context.Context) (*version.Info, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := version.Version()
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
func (svc *Service) License(ctx context.Context) (*fleet.LicenseInfo, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.AppConfig{}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
|
@ -1,114 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
)
|
||||
|
||||
// Certificate returns the PEM encoded certificate chain for osqueryd TLS termination.
|
||||
func (svc *Service) CertificateChain(ctx context.Context) ([]byte, error) {
|
||||
config, err := svc.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(config.ServerSettings.ServerURL)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "parsing serverURL")
|
||||
}
|
||||
|
||||
conn, err := connectTLS(ctx, u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chain(ctx, conn.ConnectionState(), u.Hostname())
|
||||
}
|
||||
|
||||
func connectTLS(ctx context.Context, serverURL *url.URL) (*tls.Conn, error) {
|
||||
var hostport string
|
||||
if serverURL.Port() == "" {
|
||||
hostport = net.JoinHostPort(serverURL.Host, "443")
|
||||
} else {
|
||||
hostport = serverURL.Host
|
||||
}
|
||||
|
||||
// attempt dialing twice, first with a secure conn, and then
|
||||
// if that fails, use insecure
|
||||
dial := func(insecure bool) (*tls.Conn, error) {
|
||||
conn, err := tls.Dial("tcp", hostport, &tls.Config{
|
||||
InsecureSkipVerify: insecure})
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "dial tls")
|
||||
}
|
||||
defer conn.Close()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
var (
|
||||
conn *tls.Conn
|
||||
err error
|
||||
)
|
||||
|
||||
conn, err = dial(false)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
conn, err = dial(true)
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// chain builds a PEM encoded certificate chain using the PeerCertificates
|
||||
// in tls.ConnectionState. chain uses the hostname to omit the Leaf certificate
|
||||
// from the chain.
|
||||
func chain(ctx context.Context, cs tls.ConnectionState, hostname string) ([]byte, error) {
|
||||
buf := bytes.NewBuffer([]byte(""))
|
||||
|
||||
verifyEncode := func(chain []*x509.Certificate) error {
|
||||
for _, cert := range chain {
|
||||
if len(chain) > 1 {
|
||||
// drop the leaf certificate from the chain. osqueryd does not
|
||||
// need it to establish a secure connection
|
||||
if err := cert.VerifyHostname(hostname); err == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := encodePEMCertificate(buf, cert); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// use verified chains if available(which adds the root CA), otherwise
|
||||
// use the certificate chain offered by the server (if terminated with
|
||||
// self-signed certs)
|
||||
if len(cs.VerifiedChains) != 0 {
|
||||
for _, chain := range cs.VerifiedChains {
|
||||
if err := verifyEncode(chain); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "encode verified chains pem")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := verifyEncode(cs.PeerCertificates); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "encode peer certificates pem")
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func encodePEMCertificate(buf io.Writer, cert *x509.Certificate) error {
|
||||
block := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}
|
||||
return pem.Encode(buf, block)
|
||||
}
|
@ -1,57 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCertificateChain(t *testing.T) {
|
||||
server, teardown := setupCertificateChain(t)
|
||||
defer teardown()
|
||||
|
||||
certFile := "testdata/server.pem"
|
||||
cert, err := tls.LoadX509KeyPair(certFile, "testdata/server.key")
|
||||
require.Nil(t, err)
|
||||
server.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
server.StartTLS()
|
||||
|
||||
u, err := url.Parse(server.URL)
|
||||
require.Nil(t, err)
|
||||
|
||||
conn, err := connectTLS(context.Background(), u)
|
||||
require.Nil(t, err)
|
||||
|
||||
have, want := len(conn.ConnectionState().PeerCertificates), len(cert.Certificate)
|
||||
require.Equal(t, have, want)
|
||||
|
||||
original, _ := ioutil.ReadFile(certFile)
|
||||
returned, err := chain(context.Background(), conn.ConnectionState(), "")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, returned, original)
|
||||
}
|
||||
|
||||
func echoHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
dump, err := httputil.DumpRequest(r, true)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(dump)
|
||||
})
|
||||
}
|
||||
|
||||
func setupCertificateChain(t *testing.T) (server *httptest.Server, teardown func()) {
|
||||
server = httptest.NewUnstartedServer(echoHandler())
|
||||
return server, server.Close
|
||||
}
|
@ -284,24 +284,6 @@ func (svc *Service) DestroySession(ctx context.Context) error {
|
||||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
||||
|
||||
func (svc *Service) GetInfoAboutSession(ctx context.Context, id uint) (*fleet.Session, error) {
|
||||
session, err := svc.ds.SessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = svc.validateSession(ctx, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Session, error) {
|
||||
session, err := svc.ds.SessionByKey(ctx, key)
|
||||
if err != nil {
|
||||
@ -316,19 +298,6 @@ func (svc *Service) GetSessionByKey(ctx context.Context, key string) (*fleet.Ses
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (svc *Service) DeleteSession(ctx context.Context, id uint) error {
|
||||
session, err := svc.ds.SessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
||||
|
||||
func (svc *Service) validateSession(ctx context.Context, session *fleet.Session) error {
|
||||
if session == nil {
|
||||
return fleet.NewAuthRequiredError("active session not present")
|
||||
|
@ -102,22 +102,6 @@ func (svc *Service) UserUnauthorized(ctx context.Context, id uint) (*fleet.User,
|
||||
return svc.ds.UserByID(ctx, id)
|
||||
}
|
||||
|
||||
func (svc *Service) AuthenticatedUser(ctx context.Context) (*fleet.User, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fleet.ErrNoContext
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: vc.UserID()}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !vc.IsLoggedIn() {
|
||||
return nil, fleet.NewPermissionError("not logged in")
|
||||
}
|
||||
return vc.User, nil
|
||||
}
|
||||
|
||||
// setNewPassword is a helper for changing a user's password. It should be
|
||||
// called to set the new password after proper authorization has been
|
||||
// performed.
|
||||
|
93
server/service/sessions.go
Normal file
93
server/service/sessions.go
Normal file
@ -0,0 +1,93 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get Info About Session
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type getInfoAboutSessionRequest struct {
|
||||
ID uint `url:"id"`
|
||||
}
|
||||
|
||||
type getInfoAboutSessionResponse struct {
|
||||
SessionID uint `json:"session_id"`
|
||||
UserID uint `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r getInfoAboutSessionResponse) error() error { return r.Err }
|
||||
|
||||
func getInfoAboutSessionEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*getInfoAboutSessionRequest)
|
||||
session, err := svc.GetInfoAboutSession(ctx, req.ID)
|
||||
if err != nil {
|
||||
return getInfoAboutSessionResponse{Err: err}, nil
|
||||
}
|
||||
|
||||
return getInfoAboutSessionResponse{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) GetInfoAboutSession(ctx context.Context, id uint) (*fleet.Session, error) {
|
||||
session, err := svc.ds.SessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, session, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = svc.validateSession(ctx, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Delete Session
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type deleteSessionRequest struct {
|
||||
ID uint `url:"id"`
|
||||
}
|
||||
|
||||
type deleteSessionResponse struct {
|
||||
Err error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r deleteSessionResponse) error() error { return r.Err }
|
||||
|
||||
func deleteSessionEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
req := request.(*deleteSessionRequest)
|
||||
err := svc.DeleteSession(ctx, req.ID)
|
||||
if err != nil {
|
||||
return deleteSessionResponse{Err: err}, nil
|
||||
}
|
||||
return deleteSessionResponse{}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) DeleteSession(ctx context.Context, id uint) error {
|
||||
session, err := svc.ds.SessionByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, session, fleet.ActionWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return svc.ds.DestroySession(ctx, session)
|
||||
}
|
@ -1,8 +1,87 @@
|
||||
package service
|
||||
|
||||
// TODO(mna): when migrating Session-related endpoints, add auth tests for those
|
||||
// endpoints (the auth is session-based, not user-based).
|
||||
//_, err = svc.GetInfoAboutSessionsForUser(ctx, 999)
|
||||
//checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
//_, err = svc.GetInfoAboutSessionsForUser(ctx, 888)
|
||||
//checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mock"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
)
|
||||
|
||||
func TestSessionAuth(t *testing.T) {
|
||||
ds := new(mock.Store)
|
||||
svc := newTestService(ds, nil, nil)
|
||||
|
||||
ds.ListSessionsForUserFunc = func(ctx context.Context, id uint) ([]*fleet.Session, error) {
|
||||
if id == 999 {
|
||||
return []*fleet.Session{
|
||||
{ID: 1, UserID: id, AccessedAt: time.Now()},
|
||||
}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
ds.SessionByIDFunc = func(ctx context.Context, id uint) (*fleet.Session, error) {
|
||||
return &fleet.Session{ID: id, UserID: 999, AccessedAt: time.Now()}, nil
|
||||
}
|
||||
ds.DestroySessionFunc = func(ctx context.Context, ssn *fleet.Session) error {
|
||||
return nil
|
||||
}
|
||||
ds.MarkSessionAccessedFunc = func(ctx context.Context, ssn *fleet.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *fleet.User
|
||||
shouldFailWrite bool
|
||||
shouldFailRead bool
|
||||
}{
|
||||
{
|
||||
"global admin",
|
||||
&fleet.User{ID: 111, GlobalRole: ptr.String(fleet.RoleAdmin)},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"global maintainer",
|
||||
&fleet.User{ID: 111, GlobalRole: ptr.String(fleet.RoleMaintainer)},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"global observer",
|
||||
&fleet.User{ID: 111, GlobalRole: ptr.String(fleet.RoleObserver)},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"owner user",
|
||||
&fleet.User{ID: 999},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"non-owner user",
|
||||
&fleet.User{ID: 888},
|
||||
true,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := viewer.NewContext(context.Background(), viewer.Viewer{User: tt.user})
|
||||
|
||||
_, err := svc.GetInfoAboutSessionsForUser(ctx, 999)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
|
||||
_, err = svc.GetInfoAboutSession(ctx, 1)
|
||||
checkAuthErr(t, tt.shouldFailRead, err)
|
||||
|
||||
err = svc.DeleteSession(ctx, 1)
|
||||
checkAuthErr(t, tt.shouldFailWrite, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,25 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func decodeModifyAppConfigRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
payload, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appConfigRequest{Payload: payload}, nil
|
||||
}
|
||||
|
||||
func decodeApplyEnrollSecretSpecRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req applyEnrollSecretSpecRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
|
||||
}
|
@ -10,22 +10,6 @@ import (
|
||||
"github.com/fleetdm/fleet/v4/server/sso"
|
||||
)
|
||||
|
||||
func decodeGetInfoAboutSessionRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getInfoAboutSessionRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeDeleteSessionRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
id, err := uintFromRequest(r, "id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return deleteSessionRequest{ID: uint(id)}, nil
|
||||
}
|
||||
|
||||
func decodeLoginRequest(ctx context.Context, r *http.Request) (interface{}, error) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
|
@ -11,38 +11,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeGetInfoAboutSessionRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/sessions/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeGetInfoAboutSessionRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(getInfoAboutSessionRequest)
|
||||
assert.Equal(t, uint(1), params.ID)
|
||||
}).Methods("GET")
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("GET", "/api/v1/fleet/sessions/1", nil),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeDeleteSessionRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/sessions/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
r, err := decodeDeleteSessionRequest(context.Background(), request)
|
||||
assert.Nil(t, err)
|
||||
|
||||
params := r.(deleteSessionRequest)
|
||||
assert.Equal(t, uint(1), params.ID)
|
||||
}).Methods("DELETE")
|
||||
|
||||
router.ServeHTTP(
|
||||
httptest.NewRecorder(),
|
||||
httptest.NewRequest("DELETE", "/api/v1/fleet/sessions/1", nil),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDecodeLoginRequest(t *testing.T) {
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/v1/fleet/login", func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
@ -97,6 +97,42 @@ func (svc *Service) ListUsers(ctx context.Context, opt fleet.UserListOptions) ([
|
||||
return svc.ds.ListUsers(ctx, opt)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Me (get own current user)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func meEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
|
||||
user, err := svc.AuthenticatedUser(ctx)
|
||||
if err != nil {
|
||||
return getUserResponse{Err: err}, nil
|
||||
}
|
||||
availableTeams, err := svc.ListAvailableTeamsForUser(ctx, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, fleet.ErrMissingLicense) {
|
||||
availableTeams = []*fleet.TeamSummary{}
|
||||
} else {
|
||||
return getUserResponse{Err: err}, nil
|
||||
}
|
||||
}
|
||||
return getUserResponse{User: user, AvailableTeams: availableTeams}, nil
|
||||
}
|
||||
|
||||
func (svc *Service) AuthenticatedUser(ctx context.Context) (*fleet.User, error) {
|
||||
vc, ok := viewer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fleet.ErrNoContext
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, &fleet.User{ID: vc.UserID()}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !vc.IsLoggedIn() {
|
||||
return nil, fleet.NewPermissionError("not logged in")
|
||||
}
|
||||
return vc.User, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Get User
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -399,7 +435,7 @@ func getInfoAboutSessionsForUserEndpoint(ctx context.Context, request interface{
|
||||
}
|
||||
|
||||
func (svc *Service) GetInfoAboutSessionsForUser(ctx context.Context, id uint) ([]*fleet.Session, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionWrite); err != nil {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Session{UserID: id}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -1,58 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
)
|
||||
|
||||
func (mw validationMiddleware) ModifyAppConfig(ctx context.Context, p []byte) (*fleet.AppConfig, error) {
|
||||
existing, err := mw.ds.AppConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "fetching existing app config in validation")
|
||||
}
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
var appConfig fleet.AppConfig
|
||||
err = json.Unmarshal(p, &appConfig)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err)
|
||||
}
|
||||
validateSSOSettings(appConfig, existing, invalid)
|
||||
if invalid.HasErrors() {
|
||||
return nil, ctxerr.Wrap(ctx, invalid)
|
||||
}
|
||||
return mw.Service.ModifyAppConfig(ctx, p)
|
||||
}
|
||||
|
||||
func validateSSOSettings(p fleet.AppConfig, existing *fleet.AppConfig, invalid *fleet.InvalidArgumentError) {
|
||||
if p.SSOSettings.EnableSSO {
|
||||
if p.SSOSettings.Metadata == "" && p.SSOSettings.MetadataURL == "" {
|
||||
if existing.SSOSettings.Metadata == "" && existing.SSOSettings.MetadataURL == "" {
|
||||
invalid.Append("metadata", "either metadata or metadata_url must be defined")
|
||||
}
|
||||
}
|
||||
if p.SSOSettings.Metadata != "" && p.SSOSettings.MetadataURL != "" {
|
||||
invalid.Append("metadata", "both metadata and metadata_url are defined, only one is allowed")
|
||||
}
|
||||
if p.SSOSettings.EntityID == "" {
|
||||
if existing.SSOSettings.EntityID == "" {
|
||||
invalid.Append("entity_id", "required")
|
||||
}
|
||||
} else {
|
||||
if len(p.SSOSettings.EntityID) < 5 {
|
||||
invalid.Append("entity_id", "must be 5 or more characters")
|
||||
}
|
||||
}
|
||||
if p.SSOSettings.IDPName == "" {
|
||||
if existing.SSOSettings.IDPName == "" {
|
||||
invalid.Append("idp_name", "required")
|
||||
}
|
||||
} else {
|
||||
if len(p.SSOSettings.IDPName) < 4 {
|
||||
invalid.Append("idp_name", "must be 4 or more characters")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,48 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSONotPresent(t *testing.T) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
var p fleet.AppConfig
|
||||
validateSSOSettings(p, &fleet.AppConfig{}, invalid)
|
||||
assert.False(t, invalid.HasErrors())
|
||||
|
||||
}
|
||||
|
||||
func TestNeedFieldsPresent(t *testing.T) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
config := fleet.AppConfig{
|
||||
SSOSettings: fleet.SSOSettings{
|
||||
EnableSSO: true,
|
||||
EntityID: "fleet",
|
||||
IssuerURI: "http://issuer.idp.com",
|
||||
MetadataURL: "http://isser.metadata.com",
|
||||
IDPName: "onelogin",
|
||||
},
|
||||
}
|
||||
validateSSOSettings(config, &fleet.AppConfig{}, invalid)
|
||||
assert.False(t, invalid.HasErrors())
|
||||
}
|
||||
|
||||
func TestMissingMetadata(t *testing.T) {
|
||||
invalid := &fleet.InvalidArgumentError{}
|
||||
config := fleet.AppConfig{
|
||||
SSOSettings: fleet.SSOSettings{
|
||||
EnableSSO: true,
|
||||
EntityID: "fleet",
|
||||
IssuerURI: "http://issuer.idp.com",
|
||||
IDPName: "onelogin",
|
||||
},
|
||||
}
|
||||
validateSSOSettings(config, &fleet.AppConfig{}, invalid)
|
||||
require.True(t, invalid.HasErrors())
|
||||
assert.Contains(t, invalid.Error(), "metadata")
|
||||
assert.Contains(t, invalid.Error(), "either metadata or metadata_url must be defined")
|
||||
}
|
Loading…
Reference in New Issue
Block a user