fleet/server/service/endpoint_utils.go
Roberto Dip 2fcb27ed3f
add headers denoting capabilities between fleet server / desktop / orbit (#7833)
This adds a new mechanism to allow us to handle compatibility issues between Orbit, Fleet Server and Fleet Desktop.

The general idea is to _always_ send a custom header of the form:

```
fleet-capabilities-header = "X-Fleet-Capabilities:" capabilities
capabilities              = capability * (,)
capability                = string
```

Both from the server to the clients (Orbit, Fleet Desktop) and vice-versa. For an example, see: 8c0bbdd291

Also, the following applies:

- Backwards compat: if the header is not present, assume that orbit/fleet doesn't have the capability
- The current capabilities endpoint will be removed

### Motivation

This solution is trying to solve the following problems:

- We have three independent processes communicating with each other (Fleet Desktop, Orbit and Fleet Server). Each process can be updated independently, and therefore we need a way for each process to know what features are supported by its peers.
- We originally implemented a dedicated API endpoint in the server that returned a list of the enabled capabilities (or "features"). We found this, and any other server-only solution (like API versioning), to be insufficient because:
  - There are cases in which the server also needs to know which features are supported by its clients
  - Clients needed to poll to detect whether the capabilities supported by the server had changed. By sending the capabilities on each request, we have a much cleaner way of handling different responses.
- We are also introducing an unauthenticated endpoint to get the server features, this gives us flexibility if we need to implement different authentication mechanisms, and was one of the pitfalls of the first implementation.

Related to https://github.com/fleetdm/fleet/issues/7929
2022-09-26 07:53:53 -03:00

532 lines
15 KiB
Go

package service
import (
"bufio"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"strings"
"github.com/fleetdm/fleet/v4/server/contexts/capabilities"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/log"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/gorilla/mux"
)
// handlerFunc is the signature of the functions registered on an
// authEndpointer. It receives the request context, the decoded request value
// and the fleet service, and returns the response value or an error.
type handlerFunc func(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error)
// parseTag splits a tag value of the form "name" or "name,modifier" on commas.
// It returns the name and whether the modifier is exactly "optional". A tag
// without a comma is returned untouched and is never optional; a tag with more
// than two comma-separated parts yields an error.
func parseTag(tag string) (string, bool, error) {
	pieces := strings.Split(tag, ",")
	count := len(pieces)

	if count == 0 {
		return "", false, fmt.Errorf("Error parsing %s: too few parts", tag)
	}
	if count == 1 {
		return tag, false, nil
	}
	if count == 2 {
		name, modifier := pieces[0], pieces[1]
		return name, modifier == "optional", nil
	}
	return "", false, fmt.Errorf("Error parsing %s: too many parts", tag)
}
// allFields lists the fields of the struct held by ifv (or pointed to by ifv,
// when it is a pointer). Anonymous fields whose kind is struct are expanded
// recursively in place, so their fields appear instead of the embedded field
// itself. Anything that is not a struct (including a nil pointer) yields nil.
func allFields(ifv reflect.Value) []reflect.StructField {
	if ifv.Kind() == reflect.Ptr {
		ifv = ifv.Elem()
	}
	// An invalid value reports reflect.Invalid here, so this guard also covers
	// the nil-pointer case.
	if ifv.Kind() != reflect.Struct {
		return nil
	}

	typ := ifv.Type()
	var collected []reflect.StructField
	for idx, total := 0, typ.NumField(); idx < total; idx++ {
		sf := typ.Field(idx)
		if member := ifv.Field(idx); sf.Anonymous && member.Kind() == reflect.Struct {
			collected = append(collected, allFields(member)...)
			continue
		}
		collected = append(collected, sf)
	}
	return collected
}
// requestDecoder is implemented by request values that decode themselves from
// the HTTP request. makeDecoder checks for this interface and, when present,
// delegates decoding entirely to DecodeRequest.
type requestDecoder interface {
// DecodeRequest builds the request value from r, or returns an error.
DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error)
}
// makeDecoder creates a decoder for the type of the struct passed in.
//
// Body: when the request body is not empty it is decoded as JSON into a new
// value of the struct type (gunzipped first when the content-encoding header
// is exactly "gzip"). If the body is empty and a processed field carries a
// `json` tag, decoding fails with "Expected JSON Body".
//
// `url` tag: the values list_options, user_options, host_options and
// carve_options fill the field from the corresponding *FromRequest helper.
// Any other `url` tag is treated as a path variable (of the form /path/{name}
// in the route's path) and is decoded into int, uint or string fields.
// Variables can be optional by setting the tag as follows:
// `url:"some-id,optional"`; a missing optional variable leaves the field
// untouched and skips the remaining tag processing for that field. The
// "*_options" values ignore the optional portion of the tag.
//
// `query` tag: the named query string parameter is decoded into string, uint,
// bool or int fields, or pointers to those. An empty or missing parameter is
// an error unless the tag is marked optional, in which case the field (a nil
// pointer, for pointer fields) is left as is. Bool fields are true for "1" or
// "true". The int parameter named "order_direction" accepts "asc" and "desc"
// instead of a number.
//
// If iface is nil the returned decoder yields a nil request. If iface
// implements the requestDecoder interface, the returned function calls
// iface.DecodeRequest(ctx, r) - i.e. the value itself fully controls its own
// decoding. Any other non-struct iface makes makeDecoder panic.
func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
	if iface == nil {
		return func(ctx context.Context, r *http.Request) (interface{}, error) {
			return nil, nil
		}
	}
	if rd, ok := iface.(requestDecoder); ok {
		return func(ctx context.Context, r *http.Request) (interface{}, error) {
			return rd.DecodeRequest(ctx, r)
		}
	}

	t := reflect.TypeOf(iface)
	if t.Kind() != reflect.Struct {
		panic(fmt.Sprintf("makeDecoder only understands structs, not %T", iface))
	}

	return func(ctx context.Context, r *http.Request) (interface{}, error) {
		v := reflect.New(t)
		nilBody := false

		// Peek a single byte to tell an empty body apart from one that has
		// content, without consuming it.
		buf := bufio.NewReader(r.Body)
		if _, err := buf.Peek(1); err == io.EOF {
			nilBody = true
		} else {
			var body io.Reader = buf
			if r.Header.Get("content-encoding") == "gzip" {
				gzr, err := gzip.NewReader(buf)
				if err != nil {
					return nil, err
				}
				defer gzr.Close()
				body = gzr
			}

			req := v.Interface()
			if err := json.NewDecoder(body).Decode(req); err != nil {
				return nil, err
			}
			v = reflect.ValueOf(req)
		}

		for _, f := range allFields(v) {
			field := v.Elem().FieldByName(f.Name)

			urlTagValue, ok := f.Tag.Lookup("url")

			optional := false
			var err error
			if ok {
				urlTagValue, optional, err = parseTag(urlTagValue)
				if err != nil {
					return nil, err
				}
				switch urlTagValue {
				case "list_options":
					opts, err := listOptionsFromRequest(r)
					if err != nil {
						return nil, err
					}
					field.Set(reflect.ValueOf(opts))

				case "user_options":
					opts, err := userListOptionsFromRequest(r)
					if err != nil {
						return nil, err
					}
					field.Set(reflect.ValueOf(opts))

				case "host_options":
					opts, err := hostListOptionsFromRequest(r)
					if err != nil {
						return nil, err
					}
					field.Set(reflect.ValueOf(opts))

				case "carve_options":
					opts, err := carveListOptionsFromRequest(r)
					if err != nil {
						return nil, err
					}
					field.Set(reflect.ValueOf(opts))

				default:
					// Any other value names a path variable; a missing
					// variable is only tolerated when the tag is optional.
					switch field.Kind() {
					case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
						v, err := intFromRequest(r, urlTagValue)
						if err != nil {
							if err == errBadRoute && optional {
								continue
							}
							return nil, err
						}
						field.SetInt(v)

					case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
						v, err := uintFromRequest(r, urlTagValue)
						if err != nil {
							if err == errBadRoute && optional {
								continue
							}
							return nil, err
						}
						field.SetUint(v)

					case reflect.String:
						v, err := stringFromRequest(r, urlTagValue)
						if err != nil {
							if err == errBadRoute && optional {
								continue
							}
							return nil, err
						}
						field.SetString(v)

					default:
						return nil, fmt.Errorf("unsupported type for field %s for 'url' decoding: %s", urlTagValue, field.Kind())
					}
				}
			}

			_, jsonExpected := f.Tag.Lookup("json")
			if jsonExpected && nilBody {
				return nil, errors.New("Expected JSON Body")
			}

			queryTagValue, ok := f.Tag.Lookup("query")

			if ok {
				queryTagValue, optional, err = parseTag(queryTagValue)
				if err != nil {
					return nil, err
				}
				queryVal := r.URL.Query().Get(queryTagValue)
				// if optional and it's a ptr, leave as nil
				if queryVal == "" {
					if optional {
						continue
					}
					return nil, fmt.Errorf("Param %s is required", f.Name)
				}
				if field.Kind() == reflect.Ptr {
					// create the new instance of whatever it is
					field.Set(reflect.New(field.Type().Elem()))
					field = field.Elem()
				}
				switch field.Kind() {
				case reflect.String:
					field.SetString(queryVal)
				case reflect.Uint:
					// Parse as unsigned, sized to the field, so that negative
					// or out-of-range input is rejected instead of silently
					// wrapping around when converted to an unsigned value.
					queryValUint, err := strconv.ParseUint(queryVal, 10, field.Type().Bits())
					if err != nil {
						return nil, fmt.Errorf("parsing uint from query: %w", err)
					}
					field.SetUint(queryValUint)
				case reflect.Bool:
					field.SetBool(queryVal == "1" || queryVal == "true")
				case reflect.Int:
					queryValInt := 0
					switch queryTagValue {
					case "order_direction":
						switch queryVal {
						case "desc":
							queryValInt = int(fleet.OrderDescending)
						case "asc":
							queryValInt = int(fleet.OrderAscending)
						case "":
							queryValInt = int(fleet.OrderAscending)
						default:
							// Return a nil request on error, like every other
							// failure path of this decoder.
							return nil, errors.New("unknown order_direction: " + queryVal)
						}
					default:
						queryValInt, err = strconv.Atoi(queryVal)
						if err != nil {
							return nil, fmt.Errorf("parsing int from query: %w", err)
						}
					}
					field.SetInt(int64(queryValInt))
				default:
					return nil, fmt.Errorf("Cant handle type for field %s %s", f.Name, field.Kind())
				}
			}
		}

		return v.Interface(), nil
	}
}
// authEndpointer registers versioned HTTP routes on a mux router, wrapping
// each handler with an authentication function and optional middleware. The
// With*/StartingAtVersion/EndingAtVersion/UsePathPrefix methods return
// modified copies, leaving the receiver unchanged.
type authEndpointer struct {
// svc is the service passed to authFunc and to every handlerFunc.
svc fleet.Service
// opts are the go-kit server options passed to newServer for each endpoint.
opts []kithttp.ServerOption
// r is the router on which routes are registered.
r *mux.Router
// authFunc wraps the handler endpoint; it is applied before customMiddleware.
authFunc func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint
// versions lists the valid API versions; the last one is treated as the
// most recent when deciding whether "latest" is also routed.
versions []string
// startingAtVersion, when set, drops the versions listed before it.
startingAtVersion string
// endingAtVersion, when set, drops the versions listed after it.
endingAtVersion string
// alternativePaths are extra paths registered with the same handler.
alternativePaths []string
// customMiddleware wraps the authenticated endpoint; the first entry is
// the outermost wrapper.
customMiddleware []endpoint.Middleware
// usePathPrefix registers routes with PathPrefix instead of an exact path.
usePathPrefix bool
}
// newDeviceAuthenticatedEndpointer returns an authEndpointer whose endpoints
// are wrapped with authenticatedDevice. On top of the given opts, it adds
// server options that write the fleet.ServerDeviceCapabilities header on
// responses and that store the capabilities reported by the client in the
// request context.
//
// The caller's opts slice is never written to: its capacity is clipped before
// appending, so the extra options always land in a fresh backing array even
// when the same opts slice is shared with other endpointers.
func newDeviceAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
	authFunc := func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint {
		return authenticatedDevice(svc, logger, next)
	}

	// Clip the capacity (full slice expression) so append must copy instead of
	// reusing spare capacity of the caller's backing array.
	opts = append(opts[:len(opts):len(opts)],
		// Inject the fleet.CapabilitiesHeader header to the response for device endpoints
		capabilitiesResponseFunc(fleet.ServerDeviceCapabilities),
		// Add the capabilities reported by the device to the request context
		capabilitiesContextFunc(),
	)

	return &authEndpointer{
		svc:      svc,
		opts:     opts,
		r:        r,
		authFunc: authFunc,
		versions: versions,
	}
}
// newUserAuthenticatedEndpointer returns an authEndpointer whose endpoints are
// wrapped with authenticatedUser.
func newUserAuthenticatedEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
	ep := new(authEndpointer)
	ep.svc = svc
	ep.opts = opts
	ep.r = r
	ep.authFunc = authenticatedUser
	ep.versions = versions
	return ep
}
// newHostAuthenticatedEndpointer returns an authEndpointer whose endpoints are
// wrapped with authenticatedHost, using the given logger.
func newHostAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
	ep := new(authEndpointer)
	ep.svc = svc
	ep.opts = opts
	ep.r = r
	ep.versions = versions
	ep.authFunc = func(s fleet.Service, wrapped endpoint.Endpoint) endpoint.Endpoint {
		return authenticatedHost(s, logger, wrapped)
	}
	return ep
}
// newOrbitAuthenticatedEndpointer returns an authEndpointer whose endpoints
// are wrapped with authenticatedOrbitHost. On top of the given opts, it adds
// server options that write the fleet.ServerOrbitCapabilities header on
// responses and that store the capabilities reported by the client in the
// request context.
//
// The caller's opts slice is never written to: its capacity is clipped before
// appending, so the extra options always land in a fresh backing array even
// when the same opts slice is shared with other endpointers.
func newOrbitAuthenticatedEndpointer(svc fleet.Service, logger log.Logger, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
	authFunc := func(svc fleet.Service, next endpoint.Endpoint) endpoint.Endpoint {
		return authenticatedOrbitHost(svc, logger, next)
	}

	// Clip the capacity (full slice expression) so append must copy instead of
	// reusing spare capacity of the caller's backing array.
	opts = append(opts[:len(opts):len(opts)],
		// Inject the fleet.Capabilities header to the response for Orbit hosts
		capabilitiesResponseFunc(fleet.ServerOrbitCapabilities),
		// Add the capabilities reported by Orbit to the request context
		capabilitiesContextFunc(),
	)

	return &authEndpointer{
		svc:      svc,
		opts:     opts,
		r:        r,
		authFunc: authFunc,
		versions: versions,
	}
}
// newNoAuthEndpointer returns an authEndpointer whose endpoints are wrapped
// with unauthenticatedRequest.
func newNoAuthEndpointer(svc fleet.Service, opts []kithttp.ServerOption, r *mux.Router, versions ...string) *authEndpointer {
	ep := new(authEndpointer)
	ep.svc = svc
	ep.opts = opts
	ep.r = r
	ep.authFunc = unauthenticatedRequest
	ep.versions = versions
	return ep
}
// pathReplacer turns the characters "/", "{" and "}" into underscores.
var pathReplacer = strings.NewReplacer("/", "_", "{", "_", "}", "_")

// getNameFromPathAndVerb derives a route name from an HTTP verb and a path:
// the lower-cased verb, an underscore, and the path with its trailing slashes
// and the "/api/_version_/fleet/" prefix removed and "/", "{", "}" replaced by
// underscores.
func getNameFromPathAndVerb(verb, path string) string {
	trimmed := strings.TrimRight(path, "/")
	trimmed = strings.TrimPrefix(trimmed, "/api/_version_/fleet/")
	return strings.ToLower(verb) + "_" + pathReplacer.Replace(trimmed)
}
// capabilitiesResponseFunc returns a go-kit ServerAfter option that writes the
// given capabilities to the response via writeCapabilitiesHeader and leaves
// the context unchanged.
func capabilitiesResponseFunc(capabilities fleet.CapabilityMap) kithttp.ServerOption {
	addHeader := func(ctx context.Context, w http.ResponseWriter) context.Context {
		writeCapabilitiesHeader(w, capabilities)
		return ctx
	}
	return kithttp.ServerAfter(addHeader)
}
// capabilitiesContextFunc returns a go-kit ServerBefore option that replaces
// the request context with the one built by capabilities.NewContext from the
// incoming request.
func capabilitiesContextFunc() kithttp.ServerOption {
	withCapabilities := func(ctx context.Context, r *http.Request) context.Context {
		return capabilities.NewContext(ctx, r)
	}
	return kithttp.ServerBefore(withCapabilities)
}
// writeCapabilitiesHeader sets the fleet.CapabilitiesHeader response header to
// the string form of capabilities. Nothing is written when capabilities is
// empty.
func writeCapabilitiesHeader(w http.ResponseWriter, capabilities fleet.CapabilityMap) {
	if len(capabilities) > 0 {
		w.Header().Set(fleet.CapabilitiesHeader, capabilities.String())
	}
}
// POST registers f on path for the POST method; v is the request struct used
// to build the decoder.
func (e *authEndpointer) POST(path string, f handlerFunc, v interface{}) {
	e.handleEndpoint(path, f, v, http.MethodPost)
}
// GET registers f on path for the GET method; v is the request struct used to
// build the decoder.
func (e *authEndpointer) GET(path string, f handlerFunc, v interface{}) {
	e.handleEndpoint(path, f, v, http.MethodGet)
}
// PATCH registers f on path for the PATCH method; v is the request struct used
// to build the decoder.
func (e *authEndpointer) PATCH(path string, f handlerFunc, v interface{}) {
	e.handleEndpoint(path, f, v, http.MethodPatch)
}
// DELETE registers f on path for the DELETE method; v is the request struct
// used to build the decoder.
func (e *authEndpointer) DELETE(path string, f handlerFunc, v interface{}) {
	e.handleEndpoint(path, f, v, http.MethodDelete)
}
// HEAD registers f on path for the HEAD method; v is the request struct used
// to build the decoder.
func (e *authEndpointer) HEAD(path string, f handlerFunc, v interface{}) {
	e.handleEndpoint(path, f, v, http.MethodHead)
}
// PathHandler registers a handler for the given verb and path. Instead of a
// ready-made handler it takes pathHandler, a factory that is called with the
// actual (versioned) path on which the route is mounted and returns the
// http.Handler serving it. Use it when the handler needs to know the path it
// was mounted on.
func (e *authEndpointer) PathHandler(verb, path string, pathHandler func(path string) http.Handler) {
	e.handlePathHandler(path, pathHandler, verb)
}
// handlePathHandler registers pathHandler for verb on path and on every
// alternative path of the endpointer.
//
// The set of routed versions is e.versions narrowed to the range
// [startingAtVersion, endingAtVersion] (each bound applied only when set);
// it panics if a bound is not found. "latest" is added when there is no
// ending version or the ending version is the last of e.versions. The first
// "/_version_/" segment of each path is replaced by a route variable named
// "fleetversion" that matches any of those versions. Each route is named after
// its verb and unversioned path, and is mounted as a path prefix when
// usePathPrefix is set.
func (e *authEndpointer) handlePathHandler(path string, pathHandler func(path string) http.Handler, verb string) {
	supported := e.versions

	if from := e.startingAtVersion; from != "" {
		first := -1
		for idx := range supported {
			if supported[idx] == from {
				first = idx
				break
			}
		}
		if first == -1 {
			panic("StartAtVersion is not part of the valid versions")
		}
		supported = supported[first:]
	}

	if until := e.endingAtVersion; until != "" {
		last := -1
		for idx := range supported {
			if supported[idx] == until {
				last = idx
				break
			}
		}
		if last == -1 {
			panic("EndAtVersion is not part of the valid versions")
		}
		supported = supported[:last+1]
	}

	// An endpoint that is not deprecated, or whose last version is the most
	// recent one, is also reachable through "latest".
	if e.endingAtVersion == "" || e.endingAtVersion == e.versions[len(e.versions)-1] {
		supported = append(supported, "latest")
	}

	versionSegment := fmt.Sprintf("/{fleetversion:(?:%s)}/", strings.Join(supported, "|"))

	// mount registers one route; the main path and its aliases only differ in
	// the raw path they start from.
	mount := func(rawPath string) {
		routeName := getNameFromPathAndVerb(verb, rawPath)
		routePath := strings.Replace(rawPath, "/_version_/", versionSegment, 1)
		if e.usePathPrefix {
			e.r.PathPrefix(routePath).Handler(pathHandler(routePath)).Name(routeName).Methods(verb)
			return
		}
		e.r.Handle(routePath, pathHandler(routePath)).Name(routeName).Methods(verb)
	}

	mount(path)
	for _, alias := range e.alternativePaths {
		mount(alias)
	}
}
// handleHTTPHandler registers the fixed handler h for verb on path; the same
// handler is used regardless of the path it ends up mounted on.
func (e *authEndpointer) handleHTTPHandler(path string, h http.Handler, verb string) {
	e.handlePathHandler(path, func(string) http.Handler { return h }, verb)
}
// handleEndpoint builds the HTTP handler for f (decoding requests into the
// type of v) and registers it for verb on path.
func (e *authEndpointer) handleEndpoint(path string, f handlerFunc, v interface{}, verb string) {
	e.handleHTTPHandler(path, e.makeEndpoint(f, v), verb)
}
// makeEndpoint turns f into an http.Handler. f is invoked with the
// endpointer's service, wrapped by authFunc and then by the custom
// middleware, and finally served through newServer with a decoder built from
// v and the endpointer's server options.
func (e *authEndpointer) makeEndpoint(f handlerFunc, v interface{}) http.Handler {
	handler := func(ctx context.Context, request interface{}) (interface{}, error) {
		return f(ctx, request, e.svc)
	}
	wrapped := e.authFunc(e.svc, handler)

	// Wrap from the last middleware to the first, so that the first one ends
	// up outermost.
	mws := e.customMiddleware
	for remaining := len(mws); remaining > 0; remaining-- {
		wrapped = mws[remaining-1](wrapped)
	}

	return newServer(wrapped, makeDecoder(v), e.opts)
}
// StartingAtVersion returns a copy of the endpointer whose routes start at the
// given version. The receiver is not modified.
func (e *authEndpointer) StartingAtVersion(version string) *authEndpointer {
	clone := new(authEndpointer)
	*clone = *e
	clone.startingAtVersion = version
	return clone
}
// EndingAtVersion returns a copy of the endpointer whose routes end at the
// given version. The receiver is not modified.
func (e *authEndpointer) EndingAtVersion(version string) *authEndpointer {
	clone := new(authEndpointer)
	*clone = *e
	clone.endingAtVersion = version
	return clone
}
// WithAltPaths returns a copy of the endpointer that also registers its
// handlers on the given alternative paths, replacing any set previously. The
// receiver is not modified.
func (e *authEndpointer) WithAltPaths(paths ...string) *authEndpointer {
	clone := new(authEndpointer)
	*clone = *e
	clone.alternativePaths = paths
	return clone
}
// WithCustomMiddleware returns a copy of the endpointer that wraps its
// endpoints with the given middleware, replacing any set previously. The
// receiver is not modified.
func (e *authEndpointer) WithCustomMiddleware(mws ...endpoint.Middleware) *authEndpointer {
	clone := new(authEndpointer)
	*clone = *e
	clone.customMiddleware = mws
	return clone
}
// UsePathPrefix returns a copy of the endpointer that mounts its routes as
// path prefixes rather than exact paths. The receiver is not modified.
func (e *authEndpointer) UsePathPrefix() *authEndpointer {
	clone := new(authEndpointer)
	*clone = *e
	clone.usePathPrefix = true
	return clone
}