mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
a23d208b1d
#10739 Co-authored-by: Gabriel Hernandez <ghernandez345@gmail.com> Co-authored-by: gillespi314 <73313222+gillespi314@users.noreply.github.com>
241 lines
6.7 KiB
Go
241 lines
6.7 KiB
Go
package sso
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/beevik/etree"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
|
|
dsig "github.com/russellhaering/goxmldsig"
|
|
"github.com/russellhaering/goxmldsig/etreeutils"
|
|
)
|
|
|
|
type Validator interface {
|
|
ValidateSignature(auth fleet.Auth) (fleet.Auth, error)
|
|
ValidateResponse(auth fleet.Auth) error
|
|
}
|
|
|
|
type validator struct {
|
|
context *dsig.ValidationContext
|
|
clock *dsig.Clock
|
|
metadata Metadata
|
|
expectedAudiences []string
|
|
}
|
|
|
|
func Clock(clock *dsig.Clock) func(v *validator) {
|
|
return func(v *validator) {
|
|
v.clock = clock
|
|
}
|
|
}
|
|
|
|
func WithExpectedAudience(audiences ...string) func(v *validator) {
|
|
return func(v *validator) {
|
|
v.expectedAudiences = audiences
|
|
}
|
|
}
|
|
|
|
// NewValidator is used to validate the response to an auth request.
|
|
// metadata is from the IDP.
|
|
func NewValidator(metadata Metadata, opts ...func(v *validator)) (Validator, error) {
|
|
v := validator{
|
|
metadata: metadata,
|
|
}
|
|
|
|
var idpCertStore dsig.MemoryX509CertificateStore
|
|
for _, key := range v.metadata.IDPSSODescriptor.KeyDescriptors {
|
|
if len(key.KeyInfo.X509Data.X509Certificates) == 0 {
|
|
return nil, errors.New("missing x509 cert")
|
|
}
|
|
certData, err := base64.StdEncoding.DecodeString(strings.TrimSpace(key.KeyInfo.X509Data.X509Certificates[0].Data))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decoding idp x509 cert: %w", err)
|
|
}
|
|
cert, err := x509.ParseCertificate(certData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing idp x509 cert: %w", err)
|
|
}
|
|
idpCertStore.Roots = append(idpCertStore.Roots, cert)
|
|
}
|
|
for _, opt := range opts {
|
|
opt(&v)
|
|
}
|
|
if v.clock == nil {
|
|
v.clock = dsig.NewRealClock()
|
|
}
|
|
v.context = dsig.NewDefaultValidationContext(&idpCertStore)
|
|
v.context.Clock = v.clock
|
|
return &v, nil
|
|
}
|
|
|
|
func (v *validator) ValidateResponse(auth fleet.Auth) error {
|
|
info := auth.(*resp)
|
|
// make sure response is current
|
|
onOrAfter, err := time.Parse(time.RFC3339, info.response.Assertion.Conditions.NotOnOrAfter)
|
|
if err != nil {
|
|
return fmt.Errorf("missing timestamp from condition: %w", err)
|
|
}
|
|
notBefore, err := time.Parse(time.RFC3339, info.response.Assertion.Conditions.NotBefore)
|
|
if err != nil {
|
|
return fmt.Errorf("missing timestamp from condition: %w", err)
|
|
}
|
|
currentTime := v.clock.Now()
|
|
if currentTime.After(onOrAfter) {
|
|
return errors.New("response expired")
|
|
}
|
|
if currentTime.Before(notBefore) {
|
|
return errors.New("response too early")
|
|
}
|
|
|
|
verifiesAudience := false
|
|
for _, audience := range v.expectedAudiences {
|
|
if info.response.Assertion.Conditions.AudienceRestriction.Audience == audience {
|
|
verifiesAudience = true
|
|
break
|
|
}
|
|
}
|
|
if !verifiesAudience {
|
|
return errors.New("wrong audience:" + info.response.Assertion.Conditions.AudienceRestriction.Audience)
|
|
}
|
|
|
|
if auth.UserID() == "" {
|
|
return errors.New("missing user id")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (v *validator) ValidateSignature(auth fleet.Auth) (fleet.Auth, error) {
|
|
info := auth.(*resp)
|
|
status, err := info.status()
|
|
if err != nil {
|
|
return nil, errors.New("missing or malformed response")
|
|
}
|
|
if status != Success {
|
|
return nil, fmt.Errorf("response status %s", info.statusDescription())
|
|
}
|
|
|
|
// Examine the response for attempts to exploit weaknesses in Go's
|
|
// encoding/xml
|
|
decoded := info.rawResponse()
|
|
err = rtvalidator.Validate(bytes.NewReader(decoded))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("response XML failed validation: %w", err)
|
|
}
|
|
|
|
doc := etree.NewDocument()
|
|
err = doc.ReadFromBytes(decoded)
|
|
if err != nil || doc.Root() == nil {
|
|
return nil, fmt.Errorf("parsing xml response: %w", err)
|
|
}
|
|
elt := doc.Root()
|
|
signed, err := v.validateSignature(elt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("signing verification failed: %w", err)
|
|
}
|
|
// We've verified that the response hasn't been tampered with at this point
|
|
signedDoc := etree.NewDocument()
|
|
signedDoc.SetRoot(signed)
|
|
buffer, err := signedDoc.WriteToBytes()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating signed doc buffer: %w", err)
|
|
}
|
|
var response Response
|
|
err = xml.Unmarshal(buffer, &response)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unmarshalling signed doc: %w", err)
|
|
}
|
|
info.setResponse(&response)
|
|
return info, nil
|
|
}
|
|
|
|
func (v *validator) validateSignature(elt *etree.Element) (*etree.Element, error) {
|
|
validated, err := v.context.Validate(elt)
|
|
if err == nil {
|
|
// If entire doc is signed, success, we're done.
|
|
return validated, nil
|
|
}
|
|
if err == dsig.ErrMissingSignature {
|
|
// If entire document is not signed find signed assertions, remove assertions
|
|
// that are not signed.
|
|
err = v.validateAssertionSignature(elt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return elt, nil
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
func (v *validator) validateAssertionSignature(elt *etree.Element) error {
|
|
validateAssertion := func(ctx etreeutils.NSContext, unverified *etree.Element) error {
|
|
if unverified.Parent() != elt {
|
|
return fmt.Errorf("assertion with unexpected parent: %s", unverified.Parent().Tag)
|
|
}
|
|
// Remove assertions that are not signed.
|
|
detached, err := etreeutils.NSDetatch(ctx, unverified)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
signed, err := v.context.Validate(detached)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
elt.RemoveChild(unverified)
|
|
elt.AddChild(signed)
|
|
return nil
|
|
}
|
|
return etreeutils.NSFindIterate(elt, "urn:oasis:names:tc:SAML:2.0:assertion", "Assertion", validateAssertion)
|
|
}
|
|
|
|
const (
	idPrefix   = "id"
	idSize     = 16
	idAlphabet = `1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ`
)

// generateSAMLValidID returns a random SAML ID: idPrefix followed by idSize
// characters drawn uniformly from idAlphabet using crypto/rand.
//
// There isn't anything in the SAML spec that tells us what is valid inside an
// ID other than expecting that it has to be unique and valid XML. ADFS fails
// on '=' in the ID, so we use an alphabet that we know works.
//
// Azure IdP requires that the ID begin with a letter, so we use the constant
// prefix.
//
// Random bytes are mapped onto the alphabet with rejection sampling. A plain
// b % len(idAlphabet) would make the first 256 % len(idAlphabet) characters
// (8 for the current 62-character alphabet) more likely than the rest.
func generateSAMLValidID() (string, error) {
	// unbiasedLimit is the largest multiple of len(idAlphabet) that is <= 256.
	// Bytes at or above it are discarded.
	const unbiasedLimit = 256 - 256%len(idAlphabet)

	id := make([]byte, 0, idSize)
	buf := make([]byte, idSize)
	for len(id) < idSize {
		if _, err := rand.Read(buf); err != nil {
			return "", err
		}
		for _, b := range buf {
			if int(b) >= unbiasedLimit {
				continue
			}
			id = append(id, idAlphabet[int(b)%len(idAlphabet)])
			if len(id) == idSize {
				break
			}
		}
	}
	return idPrefix + string(id), nil
}
|
|
|
|
func ValidateAudiences(metadata Metadata, auth fleet.Auth, audiences ...string) error {
|
|
validator, err := NewValidator(metadata, WithExpectedAudience(audiences...))
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("create validator from metadata: %w", err)
|
|
}
|
|
// make sure the response hasn't been tampered with
|
|
auth, err = validator.ValidateSignature(auth)
|
|
if err != nil {
|
|
return fmt.Errorf("signature validation failed: %w", err)
|
|
}
|
|
// make sure the response isn't stale
|
|
err = validator.ValidateResponse(auth)
|
|
if err != nil {
|
|
return fmt.Errorf("response validation failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|