return http status 400 for bad json request body (#8287)

This commit is contained in:
Frank Sievertsen 2022-10-18 14:43:16 +02:00 committed by GitHub
parent 60e06c087f
commit 7c3d9f007a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 12 deletions

View File

@ -0,0 +1 @@
* return http status 400 if json decoding fails.

View File

@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"strconv"
@ -117,7 +118,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if r.Header.Get("content-encoding") == "gzip" {
gzr, err := gzip.NewReader(buf)
if err != nil {
return nil, err
return nil, badRequestErr("gzip decoder error: %w", err)
}
defer gzr.Close()
body = gzr
@ -125,7 +126,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
req := v.Interface()
if err := json.NewDecoder(body).Decode(req); err != nil {
return nil, err
return nil, badRequestErr("json decoder error: %w", err)
}
v = reflect.ValueOf(req)
}
@ -180,7 +181,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if err == errBadRoute && optional {
continue
}
return nil, err
return nil, badRequestErr("intFromRequest: %w", err)
}
field.SetInt(v)
@ -190,7 +191,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if err == errBadRoute && optional {
continue
}
return nil, err
return nil, badRequestErr("uintFromRequest: %w", err)
}
field.SetUint(v)
@ -200,7 +201,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if err == errBadRoute && optional {
continue
}
return nil, err
return nil, badRequestErr("stringFromRequest: %w", err)
}
field.SetString(v)
@ -212,7 +213,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
_, jsonExpected := f.Tag.Lookup("json")
if jsonExpected && nilBody {
return nil, errors.New("Expected JSON Body")
return nil, badRequest("Expected JSON Body")
}
queryTagValue, ok := f.Tag.Lookup("query")
@ -228,7 +229,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
if optional {
continue
}
return nil, fmt.Errorf("Param %s is required", f.Name)
return nil, badRequest(fmt.Sprintf("Param %s is required", f.Name))
}
if field.Kind() == reflect.Ptr {
// create the new instance of whatever it is
@ -241,7 +242,7 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
case reflect.Uint:
queryValUint, err := strconv.Atoi(queryVal)
if err != nil {
return nil, fmt.Errorf("parsing uint from query: %w", err)
return nil, badRequestErr("parsing uint from query: %w", err)
}
field.SetUint(uint64(queryValUint))
case reflect.Bool:
@ -258,13 +259,12 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
case "":
queryValInt = int(fleet.OrderAscending)
default:
return fleet.ListOptions{},
errors.New("unknown order_direction: " + queryVal)
return fleet.ListOptions{}, badRequest("unknown order_direction: " + queryVal)
}
default:
queryValInt, err = strconv.Atoi(queryVal)
if err != nil {
return nil, fmt.Errorf("parsing int from query: %w", err)
return nil, badRequestErr("parsing int from query: %w", err)
}
}
field.SetInt(int64(queryValInt))
@ -278,6 +278,19 @@ func makeDecoder(iface interface{}) kithttp.DecodeRequestFunc {
}
}
func badRequest(msg string) error {
return &fleet.BadRequestError{Message: msg}
}
func badRequestErr(msg string, err error) error {
// ensure timeout errors don't become BadRequestErrors.
var opErr *net.OpError
if errors.As(err, &opErr) {
return fmt.Errorf(msg, err)
}
return &fleet.BadRequestError{Message: fmt.Errorf(msg, err).Error()}
}
type authEndpointer struct {
svc fleet.Service
opts []kithttp.ServerOption

View File

@ -255,7 +255,7 @@ func (s *integrationLoggerTestSuite) TestSubmitLog() {
assert.Equal(t, 1, strings.Count(logString, "x_for_ip_addr"))
// submit same payload without specifying gzip encoding fails
s.DoRawWithHeaders("POST", "/api/osquery/log", body.Bytes(), http.StatusInternalServerError, nil)
s.DoRawWithHeaders("POST", "/api/osquery/log", body.Bytes(), http.StatusBadRequest, nil)
}
func (s *integrationLoggerTestSuite) TestEnrollAgentLogsErrors() {