fleet/ee/server/licensing/licensing.go

129 lines
3.0 KiB
Go

package licensing
import (
"crypto/ecdsa"
"crypto/x509"
_ "embed"
"encoding/pem"
"errors"
"fmt"
"time"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/golang-jwt/jwt/v4"
)
const (
	// expectedAlgorithm is the JWT signing algorithm name that validate
	// requires a license token to have been signed with.
	expectedAlgorithm = "ES256"
	// expectedIssuer is the value that validate requires in the token's
	// issuer claim.
	expectedIssuer = "Fleet Device Management Inc."
)
// pubKeyPEM holds the raw contents of pubkey.pem, embedded into the binary
// by the go:embed directive below. loadPublicKey decodes it.
//
//go:embed pubkey.pem
var pubKeyPEM []byte
// loadPublicKey decodes the PEM data in pubKeyPEM (the contents of
// pubkey.pem) and returns the ECDSA public key it contains.
//
// An error is returned when no PEM block is present, when the block is not a
// parseable PKIX public key, or when the parsed key is not an ECDSA key.
func loadPublicKey() (*ecdsa.PublicKey, error) {
	// Only the first PEM block is considered; trailing data is ignored.
	pemBlock, _ := pem.Decode(pubKeyPEM)
	if pemBlock == nil {
		return nil, errors.New("no key block found in pem")
	}

	parsed, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse ecdsa key: %w", err)
	}

	// ParsePKIXPublicKey can return several key types; accept only ECDSA.
	switch key := parsed.(type) {
	case *ecdsa.PublicKey:
		return key, nil
	default:
		return nil, fmt.Errorf("%T is not *ecdsa.PublicKey", parsed)
	}
}
// LoadLicense parses and validates the given license key.
//
// An empty licenseKey is not an error: it yields a LicenseInfo with the free
// tier. Otherwise the key is parsed as a JWT carrying licenseClaims, verified
// with the key returned by loadPublicKey, and then checked by validate. A
// token whose only validation failure is that it has expired is still
// accepted and returned.
//
// Parse failures are returned wrapped as "parse license: ..." and validation
// failures as "validate license: ...".
func LoadLicense(licenseKey string) (*fleet.LicenseInfo, error) {
	// No license key
	if licenseKey == "" {
		return &fleet.LicenseInfo{Tier: fleet.TierFree}, nil
	}

	parsedToken, err := jwt.ParseWithClaims(
		licenseKey,
		&licenseClaims{},
		// Always use the same public key
		func(*jwt.Token) (interface{}, error) {
			return loadPublicKey()
		},
	)
	if err != nil {
		// Use errors.As instead of a direct type assertion so the check keeps
		// working if the validation error is ever returned wrapped.
		var vErr *jwt.ValidationError
		onlyExpired := errors.As(err, &vErr) && vErr.Errors == jwt.ValidationErrorExpired

		// If the ONLY error is that it's expired, then we ignore it. The nil
		// check keeps the assignment below from panicking should the parser
		// ever return an error without a token.
		if !onlyExpired || parsedToken == nil {
			return nil, fmt.Errorf("parse license: %w", err)
		}
		parsedToken.Valid = true
	}

	license, err := validate(parsedToken)
	if err != nil {
		return nil, fmt.Errorf("validate license: %w", err)
	}

	// for backwards compatibility we'll convert basic tier to premium
	license.ForceUpgrade()
	return license, nil
}
// licenseClaims is the claim set decoded from a license JWT: the standard
// claims plus the license-specific fields below.
type licenseClaims struct {
	// jwt.StandardClaims includes validation for iat, nbf, and exp.
	jwt.StandardClaims
	// Tier is read from the "tier" claim; validate rejects an empty value.
	Tier string `json:"tier"`
	// Devices is read from the "devices" claim and becomes
	// LicenseInfo.DeviceCount; validate rejects a zero value.
	Devices int `json:"devices"`
	// Note is read from the "note" claim and copied to LicenseInfo.Note.
	Note string `json:"note"`
}
// validate checks an already-parsed license token and converts its claims
// into a fleet.LicenseInfo.
//
// It returns an error when the token is not marked valid, was not signed with
// expectedAlgorithm, does not carry *licenseClaims, lacks a devices, tier or
// exp claim, or was not issued by expectedIssuer.
func validate(token *jwt.Token) (*fleet.LicenseInfo, error) {
	// The iat, exp and nbf claims were already validated by the JWT library
	// during parsing. ParseWithClaims should have errored for an invalid
	// token, but double-check here.
	if !token.Valid {
		return nil, errors.New("token invalid")
	}

	if alg := token.Method.Alg(); alg != expectedAlgorithm {
		return nil, fmt.Errorf("unexpected algorithm %s", alg)
	}

	lc, isLicenseClaims := token.Claims.(*licenseClaims)
	if !isLicenseClaims || lc == nil {
		return nil, fmt.Errorf("unexpected claims type %T", token.Claims)
	}

	// Cases are evaluated top to bottom, so the first failing check wins.
	switch {
	case lc.Devices == 0:
		return nil, errors.New("missing devices")
	case lc.Tier == "":
		return nil, errors.New("missing tier")
	case lc.ExpiresAt == 0:
		return nil, errors.New("missing exp")
	case lc.Issuer != expectedIssuer:
		return nil, fmt.Errorf("unexpected issuer %s", lc.Issuer)
	}

	info := &fleet.LicenseInfo{
		Tier:         lc.Tier,
		Organization: lc.Subject,
		DeviceCount:  lc.Devices,
		Expiration:   time.Unix(lc.ExpiresAt, 0),
		Note:         lc.Note,
	}
	return info, nil
}