diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index b5742f6ec..0807efbe9 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -1591,17 +1591,19 @@ void t_go_generator::generate_go_struct_reader(ostream& out, << endl; out << indent() << " return err" << endl; out << indent() << " }" << endl; + + // Mark required field as read + if ((*f_iter)->get_req() == t_field::T_REQUIRED) { + const string field_name(publicize(escape_string((*f_iter)->get_name()))); + out << indent() << " isset" << field_name << " = true" << endl; + } + out << indent() << "} else {" << endl; out << indent() << " if err := iprot.Skip(fieldTypeId); err != nil {" << endl; out << indent() << " return err" << endl; out << indent() << " }" << endl; out << indent() << "}" << endl; - // Mark required field as read - if ((*f_iter)->get_req() == t_field::T_REQUIRED) { - const string field_name(publicize(escape_string((*f_iter)->get_name()))); - out << indent() << "isset" << field_name << " = true" << endl; - } indent_down(); } diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am index b7ba87085..78d468129 100644 --- a/lib/go/test/Makefile.am +++ b/lib/go/test/Makefile.am @@ -27,6 +27,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \ MultiplexedProtocolTest.thrift \ OnewayTest.thrift \ OptionalFieldsTest.thrift \ + RequiredFieldTest.thrift \ ServicesTest.thrift \ GoTagTest.thrift \ TypedefFieldTest.thrift \ @@ -46,6 +47,7 @@ gopath: $(THRIFT) $(THRIFTTEST) \ $(THRIFT) $(THRIFTARGS) MultiplexedProtocolTest.thrift $(THRIFT) $(THRIFTARGS) OnewayTest.thrift $(THRIFT) $(THRIFTARGS) OptionalFieldsTest.thrift + $(THRIFT) $(THRIFTARGS) RequiredFieldTest.thrift $(THRIFT) $(THRIFTARGS) ServicesTest.thrift $(THRIFT) $(THRIFTARGS) GoTagTest.thrift $(THRIFT) $(THRIFTARGS) TypedefFieldTest.thrift @@ -96,6 +98,7 @@ EXTRA_DIST = \ NamespacedTest.thrift \ OnewayTest.thrift \ OptionalFieldsTest.thrift \ + RequiredFieldTest.thrift \ RefAnnotationFieldsTest.thrift \ UnionDefaultValueTest.thrift \ UnionBinaryTest.thrift \ diff --git a/lib/go/test/RequiredFieldTest.thrift b/lib/go/test/RequiredFieldTest.thrift new file mode 100644 index 000000000..4a2dcaeba --- /dev/null +++ b/lib/go/test/RequiredFieldTest.thrift @@ -0,0 +1,7 @@ +struct RequiredField { + 1: required string name +} + +struct OtherThing { + 1: required i16 value +} diff --git a/lib/go/test/tests/required_fields_test.go b/lib/go/test/tests/required_fields_test.go index 287ef60c3..3fa414ad8 100644 --- a/lib/go/test/tests/required_fields_test.go +++ b/lib/go/test/tests/required_fields_test.go @@ -20,12 +20,45 @@ package tests import ( + "context" "github.com/golang/mock/gomock" "optionalfieldstest" + "requiredfieldtest" "testing" "thrift" ) +func TestRequiredField_SucecssWhenSet(t *testing.T) { + // create a new RequiredField instance with the required field set + source := &requiredfieldtest.RequiredField{Name: "this is a test"} + sourceData, err := thrift.NewTSerializer().Write(context.Background(), source) + if err != nil { + t.Fatalf("failed to serialize %T: %v", source, err) + } + + d := thrift.NewTDeserializer() + err = d.Read(&requiredfieldtest.RequiredField{}, sourceData) + if err != nil { + t.Fatalf("Did not expect an error when trying to deserialize the requiredfieldtest.RequiredField: %v", err) + } +} + +func TestRequiredField_ErrorWhenMissing(t *testing.T) { + // create a new OtherThing instance, without setting the required field + source := &requiredfieldtest.OtherThing{} + sourceData, err := thrift.NewTSerializer().Write(context.Background(), source) + if err != nil { + t.Fatalf("failed to serialize %T: %v", source, err) + } + + // attempt to deserialize into a different type (which should fail) + d := thrift.NewTDeserializer() + err = d.Read(&requiredfieldtest.RequiredField{}, sourceData) + if err == nil { + t.Fatal("Expected an error when trying to deserialize an object which is missing a required field") + } +} + func TestStructReadRequiredFields(t *testing.T) { mockCtrl := gomock.NewController(t) protocol := NewMockTProtocol(mockCtrl)