fleet/server/service/client.go
Michal Nicpon 075702113a
Print version warning when using fleetctl (#4139)
* Remove deprecated call in fleetctl
* Remove duplicate error returned by app.Run in tests
2022-02-14 09:43:34 -07:00

276 lines
7.1 KiB
Go

package service
import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/fleetdm/fleet/v4/pkg/fleethttp"
	"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
	"github.com/fleetdm/fleet/v4/server/fleet"
)
// httpClient is the single method of *http.Client that Client relies on.
// Depending on this interface rather than the concrete type allows the HTTP
// layer to be mocked.
type httpClient interface {
	// Do sends r and returns the resulting response or a transport error.
	Do(r *http.Request) (*http.Response, error)
}
// Client sends HTTP requests to the server at addr. Build it with NewClient;
// SetToken supplies the bearer token used by AuthenticatedDo.
type Client struct {
	// addr is the server address exactly as passed to NewClient.
	addr string
	// baseURL is addr parsed; url() copies it and replaces path and query.
	baseURL *url.URL
	// urlPrefix is prepended to every request path built by url().
	urlPrefix string
	// token is sent as "Authorization: Bearer <token>" by AuthenticatedDo.
	token string
	// http performs the requests. EnableClientDebug requires it to be an
	// *http.Client.
	http httpClient
	// insecureSkipVerify records whether TLS certificate verification was
	// disabled when the client was built.
	insecureSkipVerify bool
	// writer receives the expired-license banner when the server reports an
	// expired license; set via SetClientWriter.
	writer io.Writer
}

// ClientOption customizes a Client during NewClient. A non-nil error aborts
// construction and is returned by NewClient.
type ClientOption func(*Client) error
// NewClient creates a Client for the server at addr.
//
// addr must use the https scheme unless its host is the local machine: the
// hostname "localhost" or a loopback IP address such as 127.0.0.1 or ::1.
// When rootCA is non-empty it is the path of a PEM file whose certificates
// become the only trusted roots. Otherwise the system cert pool is used,
// unless insecureSkipVerify is set. urlPrefix is prepended to every request
// path. options are applied in order after the client is built, and the first
// error returned by an option aborts construction.
func NewClient(addr string, insecureSkipVerify bool, rootCA, urlPrefix string, options ...ClientOption) (*Client, error) {
	// TODO #265 refactor all optional parameters to functional options
	// API breaking change, needs a major version release
	baseURL, err := url.Parse(addr)
	if err != nil {
		return nil, fmt.Errorf("parsing URL: %w", err)
	}

	// Compare the parsed hostname exactly. A substring check on Host would let
	// plain-HTTP remote addresses such as http://localhost.example.com or
	// http://127.0.0.1.example.net bypass the TLS requirement.
	host := baseURL.Hostname()
	isLocal := strings.EqualFold(host, "localhost")
	if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() {
		isLocal = true
	}
	if baseURL.Scheme != "https" && !isLocal {
		return nil, errors.New("address must start with https:// for remote connections")
	}

	rootCAPool := x509.NewCertPool()
	if rootCA != "" {
		// Trust only the certificates in the user-supplied PEM file.
		certs, err := ioutil.ReadFile(rootCA)
		if err != nil {
			return nil, fmt.Errorf("reading root CA: %w", err)
		}
		if ok := rootCAPool.AppendCertsFromPEM(certs); !ok {
			return nil, errors.New("failed to add certificates to root CA pool")
		}
	} else if !insecureSkipVerify {
		// Use only the system certs (doesn't work on Windows)
		rootCAPool, err = x509.SystemCertPool()
		if err != nil {
			return nil, fmt.Errorf("loading system cert pool: %w", err)
		}
	}

	httpClient := fleethttp.NewClient(fleethttp.WithTLSClientConfig(&tls.Config{
		InsecureSkipVerify: insecureSkipVerify,
		RootCAs:            rootCAPool,
	}))

	client := &Client{
		addr:               addr,
		baseURL:            baseURL,
		http:               httpClient,
		insecureSkipVerify: insecureSkipVerify,
		urlPrefix:          urlPrefix,
	}

	for _, option := range options {
		if err := option(client); err != nil {
			return nil, err
		}
	}

	return client, nil
}
// EnableClientDebug returns a ClientOption that wraps the client's transport
// in a logRoundTripper, so every request and response is logged for
// debugging. It returns an error if the client's HTTP client is not an
// *http.Client.
func EnableClientDebug() ClientOption {
	return func(c *Client) error {
		httpClient, ok := c.http.(*http.Client)
		if !ok {
			return errors.New("client is not *http.Client")
		}
		// http.Client treats a nil Transport as http.DefaultTransport. Make that
		// explicit so the logging wrapper never delegates to a nil RoundTripper.
		transport := httpClient.Transport
		if transport == nil {
			transport = http.DefaultTransport
		}
		httpClient.Transport = &logRoundTripper{roundtripper: transport}
		return nil
	}
}
// SetClientWriter returns a ClientOption that makes w the destination for the
// expired-license banner written when the server reports an expired license.
// The option never fails.
func SetClientWriter(w io.Writer) ClientOption {
	setWriter := func(c *Client) error {
		c.writer = w
		return nil
	}
	return setWriter
}
// doContextWithHeaders sends a verb request to path, relative to the client's
// base URL and prefix, with rawQuery as the query string. When params is
// non-nil it is JSON-encoded as the request body. Every entry in headers is
// set on the request.
//
// If the response carries the expired-license header value, the license
// banner is written to the client's writer. The caller is responsible for
// closing the returned response body.
func (c *Client) doContextWithHeaders(ctx context.Context, verb, path, rawQuery string, params interface{}, headers map[string]string) (*http.Response, error) {
	var payload []byte
	if params != nil {
		encoded, err := json.Marshal(params)
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "marshaling json")
		}
		payload = encoded
	}

	target := c.url(path, rawQuery).String()
	req, err := http.NewRequestWithContext(ctx, verb, target, bytes.NewBuffer(payload))
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "creating request object")
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "do request")
	}

	licenseExpired := resp.Header.Get(fleet.HeaderLicenseKey) == fleet.HeaderLicenseValueExpired
	if licenseExpired {
		fleet.WriteExpiredLicenseBanner(c.writer)
	}
	return resp, nil
}
// Do is DoContext with a background context. It sends an unauthenticated
// JSON request.
func (c *Client) Do(verb, path, rawQuery string, params interface{}) (*http.Response, error) {
	ctx := context.Background()
	return c.DoContext(ctx, verb, path, rawQuery, params)
}
// DoContext sends an unauthenticated request under ctx. It declares a JSON
// request body and asks for a JSON response. The caller must close the
// response body.
func (c *Client) DoContext(ctx context.Context, verb, path, rawQuery string, params interface{}) (*http.Response, error) {
	return c.doContextWithHeaders(ctx, verb, path, rawQuery, params, map[string]string{
		"Content-type": "application/json",
		"Accept":       "application/json",
	})
}
// AuthenticatedDo sends a JSON request with a background context, carrying
// the token set by SetToken as a Bearer Authorization header. It fails without
// sending anything if no token has been set. The caller must close the
// response body.
func (c *Client) AuthenticatedDo(verb, path, rawQuery string, params interface{}) (*http.Response, error) {
	token := c.token
	if token == "" {
		return nil, errors.New("authentication token is empty")
	}

	return c.doContextWithHeaders(context.Background(), verb, path, rawQuery, params, map[string]string{
		"Content-Type":  "application/json",
		"Accept":        "application/json",
		"Authorization": "Bearer " + token,
	})
}
// SetToken stores the bearer token that AuthenticatedDo sends with each
// request. Passing "" clears it.
func (c *Client) SetToken(t string) { c.token = t }
// url returns a fresh URL derived from the client's base URL. Its path is
// replaced with urlPrefix+path and its query string with rawQuery. The base
// URL itself is not modified.
func (c *Client) url(path, rawQuery string) *url.URL {
	full := *c.baseURL
	full.Path, full.RawQuery = c.urlPrefix+path, rawQuery
	return &full
}
// logRoundTripper is an http.RoundTripper that writes debug information about
// each request and response (method, URL, status, timing and bodies) to
// os.Stderr.
//
// Inspired by https://stackoverflow.com/a/39528716/491710 and
// github.com/motemen/go-loghttp
type logRoundTripper struct {
	// roundtripper performs the actual request. When nil,
	// http.DefaultTransport is used.
	roundtripper http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
//
// The response body is read in full so it can be logged. It is replaced with
// an in-memory copy, and the original body is always closed. A transport
// error, or a failure while reading the response body, is returned with a nil
// response.
func (l *logRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Log request.
	fmt.Fprintf(os.Stderr, "%s %s\n", req.Method, req.URL)
	// GetBody is nil for requests whose body cannot be replayed. Calling it
	// unconditionally would panic.
	if req.GetBody != nil {
		reqBody, err := req.GetBody()
		if err != nil {
			fmt.Fprintf(os.Stderr, "GetBody error: %v\n", err)
		} else {
			if _, err := io.Copy(os.Stderr, reqBody); err != nil {
				fmt.Fprintf(os.Stderr, "Copy body error: %v\n", err)
			}
			reqBody.Close()
		}
	}
	fmt.Fprintf(os.Stderr, "\n")

	rt := l.roundtripper
	if rt == nil {
		rt = http.DefaultTransport
	}

	// Perform request using underlying roundtripper.
	start := time.Now()
	res, err := rt.RoundTrip(req)
	if err != nil {
		fmt.Fprintf(os.Stderr, "RoundTrip error: %v\n", err)
		return nil, err
	}

	// Log response. Use req rather than res.Request, which a custom
	// RoundTripper may leave nil.
	took := time.Since(start).Truncate(time.Millisecond)
	fmt.Fprintf(os.Stderr, "%s %s %s (%s)\n", req.Method, req.URL, res.Status, took)

	// Echo the body to stderr while buffering it, so the caller can still read
	// it afterwards.
	resBody := &bytes.Buffer{}
	_, err = io.Copy(os.Stderr, io.TeeReader(res.Body, resBody))
	// Close the original body on every path so the connection is released.
	res.Body.Close()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Read body error: %v\n", err)
		return nil, err
	}
	res.Body = io.NopCloser(resBody)
	return res, nil
}
// authenticatedRequestWithQuery sends an authenticated request (see
// AuthenticatedDo) with the given raw query string, and decodes the JSON
// response body into responseDest.
//
// Status handling:
//   - 200 continues to decoding.
//   - 404 returns notFoundErr{}.
//   - 401 returns ErrUnauthenticated.
//   - Any other status returns an error containing the server's error text.
//
// If responseDest implements errorer and reports an error after decoding,
// that error is wrapped (so errors.Is/errors.As can reach it) and returned.
func (c *Client) authenticatedRequestWithQuery(params interface{}, verb string, path string, responseDest interface{}, query string) error {
	response, err := c.AuthenticatedDo(verb, path, query, params)
	if err != nil {
		return fmt.Errorf("%s %s: %w", verb, path, err)
	}
	defer response.Body.Close()

	switch response.StatusCode {
	case http.StatusOK:
		// ok
	case http.StatusNotFound:
		return notFoundErr{}
	case http.StatusUnauthorized:
		return ErrUnauthenticated
	default:
		return fmt.Errorf(
			"%s %s received status %d %s",
			verb, path,
			response.StatusCode,
			extractServerErrorText(response.Body),
		)
	}

	// NOTE(review): this decodes into &responseDest (a pointer to the
	// interface) rather than responseDest. It appears to rely on encoding/json
	// following a non-nil pointer stored in the interface. Kept as-is so
	// callers passing nil or non-pointer values are not affected.
	if err := json.NewDecoder(response.Body).Decode(&responseDest); err != nil {
		return fmt.Errorf("decode %s %s response: %w", verb, path, err)
	}

	if e, ok := responseDest.(errorer); ok {
		if serverErr := e.error(); serverErr != nil {
			return fmt.Errorf("%s %s error: %w", verb, path, serverErr)
		}
	}
	return nil
}
// authenticatedRequest is authenticatedRequestWithQuery without a query
// string. It sends params, checks the status and decodes the JSON response
// into responseDest.
func (c *Client) authenticatedRequest(params interface{}, verb string, path string, responseDest interface{}) error {
	const noQuery = ""
	return c.authenticatedRequestWithQuery(params, verb, path, responseDest, noQuery)
}