add backwards compatiblity to orbit enrollment via the API (#7921)

This commit is contained in:
Roberto Dip 2022-09-26 11:44:09 -03:00 committed by GitHub
parent 2fcb27ed3f
commit a509bdd0ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 34 deletions

View File

@ -491,27 +491,36 @@ func main() {
if err != nil {
return fmt.Errorf("error new orbit client: %w", err)
}
orbitNodeKey, err := getOrbitNodeKeyOrEnroll(orbitClient, c.String("root-dir"))
if err != nil {
return fmt.Errorf("error enroll: %w", err)
// ping the server to get the latest capabilities
if err := orbitClient.Ping(); err != nil {
return fmt.Errorf("error pinging the server: %w", err)
}
const orbitFlagsUpdateInterval = 30 * time.Second
flagRunner, err := update.NewFlagRunner(orbitClient, update.FlagUpdateOptions{
CheckInterval: orbitFlagsUpdateInterval,
RootDir: c.String("root-dir"),
OrbitNodeKey: orbitNodeKey,
})
if err != nil {
return err
if orbitClient.GetServerCapabilities().Has(fleet.CapabilityOrbitEndpoints) {
log.Info().Msg("Orbit endpoints are enabled")
orbitNodeKey, err := getOrbitNodeKeyOrEnroll(orbitClient, c.String("root-dir"))
if err != nil {
return fmt.Errorf("error enroll: %w", err)
}
const orbitFlagsUpdateInterval = 30 * time.Second
flagRunner, err := update.NewFlagRunner(orbitClient, update.FlagUpdateOptions{
CheckInterval: orbitFlagsUpdateInterval,
RootDir: c.String("root-dir"),
OrbitNodeKey: orbitNodeKey,
})
if err != nil {
return err
}
// do the initial flags update
_, err = flagRunner.DoFlagsUpdate()
if err != nil {
// just log, OK to continue, since we will retry
log.Info().Err(err).Msg("Initial flags update failed")
}
g.Add(flagRunner.Execute, flagRunner.Interrupt)
}
// do the initial flags update
_, err = flagRunner.DoFlagsUpdate()
if err != nil {
// just log, OK to continue, since we will retry
log.Info().Err(err).Msg("Initial flags update failed")
}
g.Add(flagRunner.Execute, flagRunner.Interrupt)
// --force is sometimes needed when an older osquery process has not
// exited properly
@ -545,6 +554,13 @@ func main() {
registerExtensionRunner(&g, r.ExtensionSocketPath(), deviceAuthToken)
checkerClient, err := service.NewOrbitClient(fleetURL, c.String("fleet-certificate"), c.Bool("insecure"), enrollSecret, uuidStr, capabilities)
if err != nil {
return fmt.Errorf("new client for capabilities checker: %w", err)
}
capabilitiesChecker := newCapabilitiesChecker(checkerClient)
g.Add(capabilitiesChecker.actor())
if c.Bool("fleet-desktop") {
desktopRunner := newDesktopRunner(desktopPath, fleetURL, deviceAuthToken, c.String("fleet-certificate"), c.Bool("insecure"))
g.Add(desktopRunner.actor())
@ -837,3 +853,60 @@ var versionCommand = &cli.Command{
return nil
},
}
// capabilitiesChecker is a helper to restart Orbit as soon as certain capabilities
// are changed in the server.
//
// This struct and its methods are designed to play nicely with `oklog.Group`.
type capabilitiesChecker struct {
client *service.OrbitClient
interruptCh chan struct{} // closed when interrupt is triggered
executeDoneCh chan struct{} // closed when execute returns
}
func newCapabilitiesChecker(client *service.OrbitClient) *capabilitiesChecker {
return &capabilitiesChecker{
client: client,
interruptCh: make(chan struct{}),
executeDoneCh: make(chan struct{}),
}
}
func (f *capabilitiesChecker) actor() (func() error, func(error)) {
return f.execute, f.interrupt
}
// execute will poll the server for capabilities and emit a stop signal to restart
// Orbit if certain capabilities are enabled.
//
// You need to add an explicit check for each capability you want to watch for
func (f *capabilitiesChecker) execute() error {
defer close(f.executeDoneCh)
capabilitiesCHeckTicker := time.NewTicker(5 * time.Minute)
for {
select {
case <-capabilitiesCHeckTicker.C:
oldCapabilities := f.client.GetServerCapabilities()
// ping the server to get the latest capabilities
if err := f.client.Ping(); err != nil {
log.Error().Err(err).Msg("pinging the server")
continue
}
newCapabilities := f.client.GetServerCapabilities()
if oldCapabilities.Has(fleet.CapabilityOrbitEndpoints) != newCapabilities.Has(fleet.CapabilityOrbitEndpoints) {
log.Info().Msg("orbit endpoints capability changed, restarting")
return nil
}
case <-f.interruptCh:
return nil
}
}
}
func (f *capabilitiesChecker) interrupt(err error) {
log.Debug().Err(err).Msg("interrupt capabilitiesChecker")
close(f.interruptCh) // Signal execute to return.
<-f.executeDoneCh // Wait for execute to return.
}

View File

@ -34,10 +34,29 @@ func (c *CapabilityMap) String() string {
return strings.Join(capabilities, ",")
}
// Has returns true if the CapabilityMap contains the given capability.
func (c CapabilityMap) Has(capability Capability) bool {
_, ok := c[capability]
return ok
}
// The following are the capabilities that Fleet supports. These can be used by
// the Fleet server, Orbit or Fleet Desktop to communicate that a given feature
// is supported.
const (
// CapabilityOrbitEndpoints denotes the presence of server endpoints
// dedicated to communicating with Orbit. These endpoints start with
// `/api/fleet/orbit`, and allow enrolling a host through Orbit among other
// functionality.
CapabilityOrbitEndpoints Capability = "orbit_endpoints"
)
// ServerOrbitCapabilities is a set of capabilities that server-side,
// Orbit-related endpoint supports.
// **it shouldn't be modified at runtime**
var ServerOrbitCapabilities = CapabilityMap{}
var ServerOrbitCapabilities = CapabilityMap{
CapabilityOrbitEndpoints: {},
}
// ServerDeviceCapabilities is a set of capabilities that server-side,
// Device-related endpoint supports.

View File

@ -53,13 +53,17 @@ func (bc *baseClient) parseResponse(verb, path string, response *http.Response,
)
}
if err := json.NewDecoder(response.Body).Decode(&responseDest); err != nil {
return fmt.Errorf("decode %s %s response: %w", verb, path, err)
}
bc.setServerCapabilities(response)
if e, ok := responseDest.(errorer); ok {
if e.error() != nil {
return fmt.Errorf("%s %s error: %w", verb, path, e.error())
if responseDest != nil {
if err := json.NewDecoder(response.Body).Decode(&responseDest); err != nil {
return fmt.Errorf("decode %s %s response: %w", verb, path, err)
}
if e, ok := responseDest.(errorer); ok {
if e.error() != nil {
return fmt.Errorf("%s %s error: %w", verb, path, e.error())
}
}
}
@ -82,11 +86,8 @@ func (bc *baseClient) setServerCapabilities(response *http.Response) {
bc.serverCapabilities.PopulateFromString(capabilities)
}
// HasServerCapability returns a boolean indicating if the server supports the
// given capability
func (bc *baseClient) HasServerCapability(capability fleet.Capability) bool {
_, ok := bc.serverCapabilities[capability]
return ok
func (bc *baseClient) GetServerCapabilities() fleet.CapabilityMap {
return bc.serverCapabilities
}
// setClientCapabilities header is used to set a header with the client
@ -155,6 +156,7 @@ func newBaseClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix strin
insecureSkipVerify: insecureSkipVerify,
urlPrefix: urlPrefix,
clientCapabilities: capabilities,
serverCapabilities: fleet.CapabilityMap{},
}
return client, nil
}

View File

@ -149,7 +149,7 @@ func TestServerCapabilities(t *testing.T) {
err = bc.parseResponse("", "", response, &struct{}{})
require.NoError(t, err)
require.True(t, bc.HasServerCapability(testCapability))
require.True(t, bc.GetServerCapabilities().Has(testCapability))
// later on, the server is downgraded and no longer has the capability
response = &http.Response{
@ -160,7 +160,7 @@ func TestServerCapabilities(t *testing.T) {
err = bc.parseResponse("", "", response, &struct{}{})
require.NoError(t, err)
require.Equal(t, fleet.CapabilityMap{}, bc.serverCapabilities)
require.False(t, bc.HasServerCapability(testCapability))
require.False(t, bc.GetServerCapabilities().Has(testCapability))
// after an upgrade, the server has many capabilities
response = &http.Response{
@ -174,6 +174,6 @@ func TestServerCapabilities(t *testing.T) {
testCapability: {},
fleet.Capability("test_capability_2"): {},
}, bc.serverCapabilities)
require.True(t, bc.HasServerCapability(testCapability))
require.True(t, bc.HasServerCapability(fleet.Capability("test_capability")))
require.True(t, bc.GetServerCapabilities().Has(testCapability))
require.True(t, bc.GetServerCapabilities().Has(fleet.Capability("test_capability")))
}

View File

@ -3,6 +3,7 @@ package service
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -78,3 +79,15 @@ func (oc *OrbitClient) GetConfig(orbitNodeKey string) (json.RawMessage, error) {
return resp.Flags, nil
}
func (oc *OrbitClient) Ping() error {
verb, path := "HEAD", "/api/latest/fleet/orbit_ping"
err := oc.request(verb, path, nil, nil)
if err == nil || errors.Is(err, notFoundErr{}) {
// notFound is ok, it means an old server without the capabilities header
return nil
}
return err
}