fleet/server/service/endpoint_campaigns.go

128 lines
4.2 KiB
Go

package service
import (
"context"
"encoding/json"
"net/http"
"regexp"
"strings"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/viewer"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/websocket"
kitlog "github.com/go-kit/kit/log"
gws "github.com/gorilla/websocket"
"github.com/igm/sockjs-go/v3/sockjs"
)
////////////////////////////////////////////////////////////////////////////////
// Stream Distributed Query Campaign Results and Metadata
////////////////////////////////////////////////////////////////////////////////
// reVersion matches a `{fleetversion:(?:...)}` placeholder inside a route
// path. Its single capture group holds the text found between `(?:` and `)}`
// (one or more characters other than `}` and `)`), for example `v1|latest`.
var reVersion = regexp.MustCompile(`\{fleetversion:\(\?:([^\}\)]+)\)\}`)
// makeStreamDistributedQueryCampaignResultsHandler returns a factory that
// builds the http.Handler streaming distributed query campaign results over
// sockjs (with websocket and raw websocket transports enabled).
//
// When config.WebsocketsAllowUnsafeOrigin is true, origin checks are disabled
// on both the sockjs options and the websocket upgrader.
//
// The returned factory takes a route path containing a
// `{fleetversion:(?:a|b|...)}` placeholder, expands it into one literal path
// per listed version, and registers a sockjs handler for each of them on a
// fresh http.ServeMux. It panics if the placeholder cannot be found in path.
//
// Each session follows the same protocol: read an auth token, authenticate
// the viewer, read a "select_campaign" message carrying a non-zero
// campaign_id, then hand the connection to svc.StreamCampaignResults. Any
// failure along the way is logged and, except for the auth token read,
// reported to the client as a JSON error. The session is always closed when
// the handler returns.
func makeStreamDistributedQueryCampaignResultsHandler(config config.ServerConfig, svc fleet.Service, logger kitlog.Logger) func(string) http.Handler {
	opt := sockjs.DefaultOptions
	opt.Websocket = true
	opt.RawWebsocket = true

	if config.WebsocketsAllowUnsafeOrigin {
		opt.CheckOrigin = func(r *http.Request) bool {
			return true
		}
		// sockjs uses gorilla websockets under-the-hood see https://github.com/igm/sockjs-go/blob/master/v3/sockjs/rawwebsocket.go#L12-L14
		opt.WebsocketUpgrader = &gws.Upgrader{
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		}
	}

	return func(path string) http.Handler {
		// expand the path's versions (with regex) to all literal paths (no regex),
		// because sockjs requires the (static, literal) path prefix as argument to
		// create the handler so that it can trim it from the request's URL to get
		// the special path values (such as the session id).
		matches := reVersion.FindStringSubmatch(path)
		if len(matches) == 0 {
			panic("unexpected path, could not expand fleetversion: " + path)
		}

		versions := strings.Split(matches[1], "|")
		literalPaths := make([]string, len(versions))
		for i, ver := range versions {
			lp := reVersion.ReplaceAllStringFunc(path, func(_ string) string { return ver })
			literalPaths[i] = lp
		}

		sockHandler := func(session sockjs.Session) {
			conn := &websocket.Conn{Session: session}
			defer func() {
				// Last line of defense: report a panic to the client instead
				// of letting it escape the handler, then close the session
				// in every case.
				if p := recover(); p != nil {
					logger.Log("err", p, "msg", "panic in result handler")
					conn.WriteJSONError("panic in result handler") //nolint:errcheck
				}
				session.Close(0, "none")
			}()

			// Receive the auth bearer token
			token, err := conn.ReadAuthToken()
			if err != nil {
				logger.Log("err", err, "msg", "failed to read auth token")
				return
			}

			// Authenticate with the token
			vc, err := authViewer(context.Background(), string(token), svc)
			if err != nil || !vc.CanPerformActions() {
				logger.Log("err", err, "msg", "unauthorized viewer")
				conn.WriteJSONError("unauthorized") //nolint:errcheck
				return
			}

			ctx := viewer.NewContext(context.Background(), *vc)

			msg, err := conn.ReadJSONMessage()
			if err != nil {
				logger.Log("err", err, "msg", "reading select_campaign JSON")
				conn.WriteJSONError("error reading select_campaign") //nolint:errcheck
				return
			}
			if msg.Type != "select_campaign" {
				logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type)
				conn.WriteJSONError("expected select_campaign") //nolint:errcheck
				return
			}

			// The payload is client-controlled: use a checked assertion so an
			// unexpected or nil Data value produces a regular error response
			// instead of a panic.
			raw, ok := msg.Data.(*json.RawMessage)
			if !ok || raw == nil {
				logger.Log("err", "unexpected select_campaign data type", "msg", "unmarshaling select_campaign data")
				conn.WriteJSONError("error unmarshaling select_campaign data") //nolint:errcheck
				return
			}

			var info struct {
				CampaignID uint `json:"campaign_id"`
			}
			if err := json.Unmarshal(*raw, &info); err != nil {
				logger.Log("err", err, "msg", "unmarshaling select_campaign data")
				conn.WriteJSONError("error unmarshaling select_campaign data") //nolint:errcheck
				return
			}
			if info.CampaignID == 0 {
				logger.Log("err", "campaign ID not set")
				conn.WriteJSONError("0 is not a valid campaign ID") //nolint:errcheck
				return
			}

			svc.StreamCampaignResults(ctx, conn, info.CampaignID)
		}

		// multiplex the requests to each literal path that this endpoint supports,
		// with the corresponding sockjs handler to handle that specific path.
		mux := http.NewServeMux()
		for _, lp := range literalPaths {
			// important: sockjs' path must not have the trailing slash, but the mux
			// needs it in order to match it as a path prefix (subtree).
			sockPath := strings.TrimSuffix(lp, "/")
			mux.Handle(lp, sockjs.NewHandler(sockPath, opt, sockHandler))
		}
		return mux
	}
}