From 9e864d57026b4905862108418ba9482892fb1f65 Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Fri, 17 Apr 2020 11:53:39 -0700 Subject: [PATCH] Add UnsetHeader to go library Client: go We already have SetHeader and GetHeader helper functions in the go library to deal with THeader injected into the context object. But we didn't provide a way to unset/delete a key from the context object. This will be useful with the TSimpleServer.SetForwardHeaders API. In the scenario that a thrift server want to auto forward certain headers to other upstream thrift servers as the fallback, but during the handling of the request might decide to remove some of the auto forward headers. This is also achievable through mutate the write header list, but since that's a list, finding one key from the list and remove it is much more hassle. --- lib/go/thrift/header_context.go | 9 +++++++++ lib/go/thrift/header_context_test.go | 13 ++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/go/thrift/header_context.go b/lib/go/thrift/header_context.go index 21e880d66..ac9bd4882 100644 --- a/lib/go/thrift/header_context.go +++ b/lib/go/thrift/header_context.go @@ -44,6 +44,15 @@ func SetHeader(ctx context.Context, key, value string) context.Context { ) } +// UnsetHeader unsets a previously set header in the context. +func UnsetHeader(ctx context.Context, key string) context.Context { + return context.WithValue( + ctx, + headerKey(key), + nil, + ) +} + // GetHeader returns a value of the given header from the context. func GetHeader(ctx context.Context, key string) (value string, ok bool) { if v := ctx.Value(headerKey(key)); v != nil { diff --git a/lib/go/thrift/header_context_test.go b/lib/go/thrift/header_context_test.go index a1ea2d093..16b10c858 100644 --- a/lib/go/thrift/header_context_test.go +++ b/lib/go/thrift/header_context_test.go @@ -25,7 +25,7 @@ import ( "testing" ) -func TestSetGetHeader(t *testing.T) { +func TestSetGetUnsetHeader(t *testing.T) { const ( key = "foo" value = "bar" @@ -68,6 +68,17 @@ func TestSetGetHeader(t *testing.T) { } }, ) + + t.Run( + "Unset", + func(t *testing.T) { + ctx := UnsetHeader(ctx, key) + + if _, ok := GetHeader(ctx, key); ok { + t.Errorf("GetHeader returned ok on unset key %q", key) + } + }, + ) } func TestReadKeyList(t *testing.T) {