THRIFT-2854 Go Struct writer and reader looses important error information

Client: Go
Patch: Chi Vinh Le <cvl@chinet.info>

This closes #291

Fixes error reporting in go generator
This commit is contained in:
Jens Geyer 2014-11-30 15:07:18 +01:00
parent 200150d330
commit 527b6d9c86
6 changed files with 612 additions and 31 deletions

View File

@ -1234,14 +1234,14 @@ void t_go_generator::generate_go_struct_reader(ofstream& out,
<< endl;
indent_up();
out << indent() << "if _, err := iprot.ReadStructBegin(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"%T read error: %s\", p, err)" << endl << indent() << "}" << endl;
<< " return thrift.PrependError(fmt.Sprintf(\"%T read error: \", p), err)" << endl << indent() << "}" << endl;
// Loop over reading in fields
indent(out) << "for {" << endl;
indent_up();
// Read beginning field marker
out << indent() << "_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()" << endl << indent()
<< "if err != nil {" << endl << indent()
<< " return fmt.Errorf(\"%T field %d read error: %s\", p, fieldId, err)" << endl << indent()
<< " return thrift.PrependError(fmt.Sprintf(\"%T field %d read error: \", p, fieldId), err)" << endl << indent()
<< "}" << endl;
// Check for field STOP marker and break
out << indent() << "if fieldTypeId == thrift.STOP { break; }" << endl;
@ -1302,7 +1302,7 @@ void t_go_generator::generate_go_struct_reader(ofstream& out,
<< " return err" << endl << indent() << "}" << endl;
indent_down();
out << indent() << "}" << endl << indent() << "if err := iprot.ReadStructEnd(); err != nil {"
<< endl << indent() << " return fmt.Errorf(\"%T read struct end error: %s\", p, err)" << endl
<< endl << indent() << " return thrift.PrependError(fmt.Sprintf(\"%T read struct end error: \", p), err)" << endl
<< indent() << "}" << endl << indent() << "return nil" << endl;
indent_down();
out << indent() << "}" << endl << endl;
@ -1338,7 +1338,7 @@ void t_go_generator::generate_go_struct_writer(ofstream& out,
indent(out) << "func (p *" << tstruct_name << ") Write(oprot thrift.TProtocol) error {" << endl;
indent_up();
out << indent() << "if err := oprot.WriteStructBegin(\"" << name << "\"); err != nil {" << endl
<< indent() << " return fmt.Errorf(\"%T write struct begin error: %s\", p, err) }" << endl;
<< indent() << " return thrift.PrependError(fmt.Sprintf(\"%T write struct begin error: \", p), err) }" << endl;
string field_name;
string escape_field_name;
@ -1363,9 +1363,9 @@ void t_go_generator::generate_go_struct_writer(ofstream& out,
// Write the struct map
out << indent() << "if err := oprot.WriteFieldStop(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"write field stop error: %s\", err) }" << endl << indent()
<< " return thrift.PrependError(\"write field stop error: \", err) }" << endl << indent()
<< "if err := oprot.WriteStructEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"write struct stop error: %s\", err) }" << endl << indent()
<< " return thrift.PrependError(\"write struct stop error: \", err) }" << endl << indent()
<< "return nil" << endl;
indent_down();
out << indent() << "}" << endl << endl;
@ -1395,16 +1395,16 @@ void t_go_generator::generate_go_struct_writer(ofstream& out,
out << indent() << "if err := oprot.WriteFieldBegin(\"" << escape_field_name << "\", "
<< type_to_enum((*f_iter)->get_type()) << ", " << field_id << "); err != nil {" << endl
<< indent() << " return fmt.Errorf(\"%T write field begin error " << field_id << ":"
<< escape_field_name << ": %s\", p, err); }" << endl;
<< indent() << " return thrift.PrependError(fmt.Sprintf(\"%T write field begin error " << field_id << ":"
<< escape_field_name << ": \", p), err) }" << endl;
// Write field contents
generate_serialize_field(out, *f_iter, "p.");
// Write field closer
out << indent() << "if err := oprot.WriteFieldEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"%T write field end error " << field_id << ":" << escape_field_name
<< ": %s\", p, err); }" << endl;
<< " return thrift.PrependError(fmt.Sprintf(\"%T write field end error " << field_id << ":" << escape_field_name
<< ": \", p), err) }" << endl;
if (field_required == t_field::T_OPTIONAL) {
indent_down();
@ -2450,8 +2450,8 @@ void t_go_generator::generate_deserialize_field(ofstream& out,
out << "ReadI32()";
}
out << "; err != nil {" << endl << indent() << "return fmt.Errorf(\"error reading field "
<< tfield->get_key() << ": %s\", err)" << endl;
out << "; err != nil {" << endl << indent() << "return thrift.PrependError(\"error reading field "
<< tfield->get_key() << ": \", err)" << endl;
out << "} else {" << endl;
string wrap;
@ -2490,7 +2490,7 @@ void t_go_generator::generate_deserialize_struct(ofstream& out,
out << indent() << prefix << eq << (pointer_field ? "&" : "");
generate_go_struct_initializer(out, tstruct);
out << indent() << "if err := " << prefix << ".Read(iprot); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"%T error reading struct: %s\", " << prefix << ", err)" << endl
<< " return thrift.PrependError(fmt.Sprintf(\"%T error reading struct: \", " << prefix << "), err)" << endl
<< indent() << "}" << endl;
}
@ -2514,21 +2514,21 @@ void t_go_generator::generate_deserialize_container(ofstream& out,
if (ttype->is_map()) {
out << indent() << "_, _, size, err := iprot.ReadMapBegin()" << endl << indent()
<< "if err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error reading map begin: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error reading map begin: \", err)" << endl << indent() << "}"
<< endl << indent() << "tMap := make(" << type_to_go_type(orig_type) << ", size)" << endl
<< indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tMap" << endl;
} else if (ttype->is_set()) {
t_set* t = (t_set*)ttype;
out << indent() << "_, size, err := iprot.ReadSetBegin()" << endl << indent()
<< "if err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error reading set begin: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error reading set begin: \", err)" << endl << indent() << "}"
<< endl << indent() << "tSet := make(map["
<< type_to_go_key_type(t->get_elem_type()->get_true_type()) << "]bool, size)" << endl
<< indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tSet" << endl;
} else if (ttype->is_list()) {
out << indent() << "_, size, err := iprot.ReadListBegin()" << endl << indent()
<< "if err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error reading list begin: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error reading list begin: \", err)" << endl << indent() << "}"
<< endl << indent() << "tSlice := make(" << type_to_go_type(orig_type) << ", 0, size)"
<< endl << indent() << prefix << eq << " " << (pointer_field ? "&" : "") << "tSlice"
<< endl;
@ -2558,15 +2558,15 @@ void t_go_generator::generate_deserialize_container(ofstream& out,
// Read container end
if (ttype->is_map()) {
out << indent() << "if err := iprot.ReadMapEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error reading map end: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error reading map end: \", err)" << endl << indent() << "}"
<< endl;
} else if (ttype->is_set()) {
out << indent() << "if err := iprot.ReadSetEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error reading set end: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error reading set end: \", err)" << endl << indent() << "}"
<< endl;
} else if (ttype->is_list()) {
out << indent() << "if err := iprot.ReadListEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error reading list end: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error reading list end: \", err)" << endl << indent() << "}"
<< endl;
}
}
@ -2697,9 +2697,9 @@ void t_go_generator::generate_serialize_field(ofstream& out,
out << "WriteI32(int32(" << name << "))";
}
out << "; err != nil {" << endl << indent() << "return fmt.Errorf(\"%T."
out << "; err != nil {" << endl << indent() << "return thrift.PrependError(fmt.Sprintf(\"%T."
<< escape_string(tfield->get_name()) << " (" << tfield->get_key()
<< ") field write error: %s\", p, err) }" << endl;
<< ") field write error: \", p), err) }" << endl;
} else {
throw "compiler error: Invalid type in generate_serialize_field '" + type->get_name()
+ "' for field '" + name + "'";
@ -2715,7 +2715,7 @@ void t_go_generator::generate_serialize_field(ofstream& out,
void t_go_generator::generate_serialize_struct(ofstream& out, t_struct* tstruct, string prefix) {
(void)tstruct;
out << indent() << "if err := " << prefix << ".Write(oprot); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"%T error writing struct: %s\", " << prefix << ", err)" << endl
<< " return thrift.PrependError(fmt.Sprintf(\"%T error writing struct: \", " << prefix << "), err)" << endl
<< indent() << "}" << endl;
}
@ -2731,19 +2731,19 @@ void t_go_generator::generate_serialize_container(ofstream& out,
<< type_to_enum(((t_map*)ttype)->get_key_type()) << ", "
<< type_to_enum(((t_map*)ttype)->get_val_type()) << ", "
<< "len(" << prefix << ")); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error writing map begin: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error writing map begin: \", err)" << endl << indent() << "}"
<< endl;
} else if (ttype->is_set()) {
out << indent() << "if err := oprot.WriteSetBegin("
<< type_to_enum(((t_set*)ttype)->get_elem_type()) << ", "
<< "len(" << prefix << ")); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error writing set begin: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error writing set begin: \", err)" << endl << indent() << "}"
<< endl;
} else if (ttype->is_list()) {
out << indent() << "if err := oprot.WriteListBegin("
<< type_to_enum(((t_list*)ttype)->get_elem_type()) << ", "
<< "len(" << prefix << ")); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error writing list begin: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error writing list begin: \", err)" << endl << indent() << "}"
<< endl;
} else {
throw "compiler error: Invalid type in generate_serialize_container '" + ttype->get_name()
@ -2776,15 +2776,15 @@ void t_go_generator::generate_serialize_container(ofstream& out,
if (ttype->is_map()) {
out << indent() << "if err := oprot.WriteMapEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error writing map end: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error writing map end: \", err)" << endl << indent() << "}"
<< endl;
} else if (ttype->is_set()) {
out << indent() << "if err := oprot.WriteSetEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error writing set end: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error writing set end: \", err)" << endl << indent() << "}"
<< endl;
} else if (ttype->is_list()) {
out << indent() << "if err := oprot.WriteListEnd(); err != nil {" << endl << indent()
<< " return fmt.Errorf(\"error writing list end: %s\", err)" << endl << indent() << "}"
<< " return thrift.PrependError(\"error writing list end: \", err)" << endl << indent() << "}"
<< endl;
}
}

View File

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
* Contains some contributions under the Thrift Software License.
* Please see doc/old-thrift-license.txt in the Thrift distribution for
* details.
*/
struct TestStruct
{
1: map<string, string> m,
2: list<string> l,
3: set<string> s,
4: i32 i
}
service ErrorTest
{
TestStruct testStruct(1: TestStruct thing)
}

View File

@ -30,7 +30,8 @@ gopath: $(top_srcdir)/compiler/cpp/thrift $(THRIFTTEST) \
ServicesTest.thrift \
GoTagTest.thrift \
TypedefFieldTest.thrift \
RefAnnotationFieldsTest.thrift
RefAnnotationFieldsTest.thrift \
ErrorTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) -r IncludesTest.thrift
@ -42,6 +43,7 @@ gopath: $(top_srcdir)/compiler/cpp/thrift $(THRIFTTEST) \
$(THRIFT) GoTagTest.thrift
$(THRIFT) TypedefFieldTest.thrift
$(THRIFT) RefAnnotationFieldsTest.thrift
$(THRIFT) ErrorTest.thrift
GOPATH=`pwd`/gopath $(GO) get code.google.com/p/gomock/gomock
ln -nfs ../../../thrift gopath/src/thrift
ln -nfs ../../tests gopath/src/tests
@ -53,7 +55,8 @@ check: gopath
binarykeytest \
servicestest \
typedeffieldtest \
refannotationfieldstest
refannotationfieldstest \
errortest
GOPATH=`pwd`/gopath $(GO) test thrift tests
clean-local:
@ -73,4 +76,5 @@ EXTRA_DIST = \
OptionalFieldsTest.thrift \
RefAnnotationFieldsTest.thrift \
ServicesTest.thrift \
TypedefFieldTest.thrift
TypedefFieldTest.thrift \
ErrorTest.thrift

View File

@ -0,0 +1,455 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package tests
import (
"code.google.com/p/gomock/gomock"
"errors"
"errortest"
"testing"
"thrift"
)
// Setup mock to fail at a certain position. Return true if position exists otherwise false.
func prepareClientProtocolFailure(protocol *MockTProtocol, failAt int, failWith error) bool {
var err error = nil
if failAt == 0 {
err = failWith
}
last := protocol.EXPECT().WriteMessageBegin("testStruct", thrift.TMessageType(1), int32(1)).Return(err)
if failAt == 0 {
return true
}
if failAt == 1 {
err = failWith
}
last = protocol.EXPECT().WriteStructBegin("testStruct_args").Return(err).After(last)
if failAt == 1 {
return true
}
if failAt == 2 {
err = failWith
}
last = protocol.EXPECT().WriteFieldBegin("thing", thrift.TType(thrift.STRUCT), int16(1)).Return(err).After(last)
if failAt == 2 {
return true
}
if failAt == 3 {
err = failWith
}
last = protocol.EXPECT().WriteStructBegin("TestStruct").Return(err).After(last)
if failAt == 3 {
return true
}
if failAt == 4 {
err = failWith
}
last = protocol.EXPECT().WriteFieldBegin("m", thrift.TType(thrift.MAP), int16(1)).Return(err).After(last)
if failAt == 4 {
return true
}
if failAt == 5 {
err = failWith
}
last = protocol.EXPECT().WriteMapBegin(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0).Return(err).After(last)
if failAt == 5 {
return true
}
if failAt == 6 {
err = failWith
}
last = protocol.EXPECT().WriteMapEnd().Return(err).After(last)
if failAt == 6 {
return true
}
if failAt == 7 {
err = failWith
}
last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
if failAt == 7 {
return true
}
if failAt == 8 {
err = failWith
}
last = protocol.EXPECT().WriteFieldBegin("l", thrift.TType(thrift.LIST), int16(2)).Return(err).After(last)
if failAt == 8 {
return true
}
if failAt == 9 {
err = failWith
}
last = protocol.EXPECT().WriteListBegin(thrift.TType(thrift.STRING), 0).Return(err).After(last)
if failAt == 9 {
return true
}
if failAt == 10 {
err = failWith
}
last = protocol.EXPECT().WriteListEnd().Return(err).After(last)
if failAt == 10 {
return true
}
if failAt == 11 {
err = failWith
}
last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
if failAt == 11 {
return true
}
if failAt == 12 {
err = failWith
}
last = protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.SET), int16(3)).Return(err).After(last)
if failAt == 12 {
return true
}
if failAt == 13 {
err = failWith
}
last = protocol.EXPECT().WriteSetBegin(thrift.TType(thrift.STRING), 0).Return(err).After(last)
if failAt == 13 {
return true
}
if failAt == 14 {
err = failWith
}
last = protocol.EXPECT().WriteSetEnd().Return(err).After(last)
if failAt == 14 {
return true
}
if failAt == 15 {
err = failWith
}
last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
if failAt == 15 {
return true
}
if failAt == 16 {
err = failWith
}
last = protocol.EXPECT().WriteFieldBegin("i", thrift.TType(thrift.I32), int16(4)).Return(err).After(last)
if failAt == 16 {
return true
}
if failAt == 17 {
err = failWith
}
last = protocol.EXPECT().WriteI32(int32(3)).Return(err).After(last)
if failAt == 17 {
return true
}
if failAt == 18 {
err = failWith
}
last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
if failAt == 18 {
return true
}
if failAt == 19 {
err = failWith
}
last = protocol.EXPECT().WriteFieldStop().Return(err).After(last)
if failAt == 19 {
return true
}
if failAt == 20 {
err = failWith
}
last = protocol.EXPECT().WriteStructEnd().Return(err).After(last)
if failAt == 20 {
return true
}
if failAt == 21 {
err = failWith
}
last = protocol.EXPECT().WriteFieldEnd().Return(err).After(last)
if failAt == 21 {
return true
}
if failAt == 22 {
err = failWith
}
last = protocol.EXPECT().WriteFieldStop().Return(err).After(last)
if failAt == 22 {
return true
}
if failAt == 23 {
err = failWith
}
last = protocol.EXPECT().WriteStructEnd().Return(err).After(last)
if failAt == 23 {
return true
}
if failAt == 24 {
err = failWith
}
last = protocol.EXPECT().WriteMessageEnd().Return(err).After(last)
if failAt == 24 {
return true
}
if failAt == 25 {
err = failWith
}
last = protocol.EXPECT().Flush().Return(err).After(last)
if failAt == 25 {
return true
}
if failAt == 26 {
err = failWith
}
last = protocol.EXPECT().ReadMessageBegin().Return("testStruct", thrift.TMessageType(1), int32(1), err).After(last)
if failAt == 26 {
return true
}
if failAt == 27 {
err = failWith
}
last = protocol.EXPECT().ReadStructBegin().Return("testStruct_args", err).After(last)
if failAt == 27 {
return true
}
if failAt == 28 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STRUCT), int16(0), err).After(last)
if failAt == 28 {
return true
}
if failAt == 29 {
err = failWith
}
last = protocol.EXPECT().ReadStructBegin().Return("TestStruct", err).After(last)
if failAt == 29 {
return true
}
if failAt == 30 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("m", thrift.TType(thrift.MAP), int16(1), err).After(last)
if failAt == 30 {
return true
}
if failAt == 31 {
err = failWith
}
last = protocol.EXPECT().ReadMapBegin().Return(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0, err).After(last)
if failAt == 31 {
return true
}
if failAt == 32 {
err = failWith
}
last = protocol.EXPECT().ReadMapEnd().Return(err).After(last)
if failAt == 32 {
return true
}
if failAt == 33 {
err = failWith
}
last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
if failAt == 33 {
return true
}
if failAt == 34 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("l", thrift.TType(thrift.LIST), int16(2), err).After(last)
if failAt == 34 {
return true
}
if failAt == 35 {
err = failWith
}
last = protocol.EXPECT().ReadListBegin().Return(thrift.TType(thrift.STRING), 0, err).After(last)
if failAt == 35 {
return true
}
if failAt == 36 {
err = failWith
}
last = protocol.EXPECT().ReadListEnd().Return(err).After(last)
if failAt == 36 {
return true
}
if failAt == 37 {
err = failWith
}
last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
if failAt == 37 {
return true
}
if failAt == 38 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("s", thrift.TType(thrift.SET), int16(3), err).After(last)
if failAt == 38 {
return true
}
if failAt == 39 {
err = failWith
}
last = protocol.EXPECT().ReadSetBegin().Return(thrift.TType(thrift.STRING), 0, err).After(last)
if failAt == 39 {
return true
}
if failAt == 40 {
err = failWith
}
last = protocol.EXPECT().ReadSetEnd().Return(err).After(last)
if failAt == 40 {
return true
}
if failAt == 41 {
err = failWith
}
last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
if failAt == 41 {
return true
}
if failAt == 42 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("i", thrift.TType(thrift.I32), int16(4), err).After(last)
if failAt == 42 {
return true
}
if failAt == 43 {
err = failWith
}
last = protocol.EXPECT().ReadI32().Return(int32(3), err).After(last)
if failAt == 43 {
return true
}
if failAt == 44 {
err = failWith
}
last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
if failAt == 44 {
return true
}
if failAt == 45 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(5), err).After(last)
if failAt == 45 {
return true
}
if failAt == 46 {
err = failWith
}
last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
if failAt == 46 {
return true
}
if failAt == 47 {
err = failWith
}
last = protocol.EXPECT().ReadFieldEnd().Return(err).After(last)
if failAt == 47 {
return true
}
if failAt == 48 {
err = failWith
}
last = protocol.EXPECT().ReadFieldBegin().Return("_", thrift.TType(thrift.STOP), int16(1), err).After(last)
if failAt == 48 {
return true
}
if failAt == 49 {
err = failWith
}
last = protocol.EXPECT().ReadStructEnd().Return(err).After(last)
if failAt == 49 {
return true
}
if failAt == 50 {
err = failWith
}
last = protocol.EXPECT().ReadMessageEnd().Return(err).After(last)
if failAt == 50 {
return true
}
return false
}
func TestClientReportTTransportErrors(t *testing.T) {
mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
thing := errortest.NewTestStruct()
thing.M = make(map[string]string)
thing.L = make([]string, 0)
thing.S = make(map[string]bool)
thing.I = 3
err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
for i := 0; ; i++ {
protocol := NewMockTProtocol(mockCtrl)
if !prepareClientProtocolFailure(protocol, i, err) {
return
}
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestStruct(thing)
err2, ok := retErr.(thrift.TTransportException)
if !ok {
t.Fatal("Expected a TTrasportException")
}
if err2.TypeId() != err.TypeId() {
t.Fatal("Expected a same error type id")
}
mockCtrl.Finish()
}
}
func TestClientReportTProtocolErrors(t *testing.T) {
mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
thing := errortest.NewTestStruct()
thing.M = make(map[string]string)
thing.L = make([]string, 0)
thing.S = make(map[string]bool)
thing.I = 3
err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test"))
for i := 0; ; i++ {
protocol := NewMockTProtocol(mockCtrl)
if !prepareClientProtocolFailure(protocol, i, err) {
return
}
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestStruct(thing)
err2, ok := retErr.(thrift.TProtocolException)
if !ok {
t.Fatal("Expected a TProtocolException")
}
if err2.TypeId() != err.TypeId() {
t.Fatal("Expected a same error type id")
}
mockCtrl.Finish()
}
}

View File

@ -19,7 +19,26 @@
package thrift
import (
"errors"
)
// Generic Thrift exception
type TException interface {
error
}
// Prepends additional information to an error without loosing the Thrift exception interface
func PrependError(prepend string, err error) error {
if t, ok := err.(TTransportException); ok {
return NewTTransportException(t.TypeId(), prepend+t.Error())
}
if t, ok := err.(TProtocolException); ok {
return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error()))
}
if t, ok := err.(TApplicationException); ok {
return NewTApplicationException(t.TypeId(), prepend+t.Error())
}
return errors.New(prepend + err.Error())
}

View File

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
"testing"
)
func TestPrependError(t *testing.T) {
err := NewTApplicationException(INTERNAL_ERROR, "original error")
err2, ok := PrependError("Prepend: ", err).(TApplicationException)
if !ok {
t.Fatal("Couldn't cast error TApplicationException")
}
if err2.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err2.TypeId() != INTERNAL_ERROR {
t.Fatal("Unexpected type error")
}
err3 := NewTProtocolExceptionWithType(INVALID_DATA, errors.New("original error"))
err4, ok := PrependError("Prepend: ", err3).(TProtocolException)
if !ok {
t.Fatal("Couldn't cast error TProtocolException")
}
if err4.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err4.TypeId() != INVALID_DATA {
t.Fatal("Unexpected type error")
}
err5 := NewTTransportException(TIMED_OUT, "original error")
err6, ok := PrependError("Prepend: ", err5).(TTransportException)
if !ok {
t.Fatal("Couldn't cast error TTransportException")
}
if err6.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err6.TypeId() != TIMED_OUT {
t.Fatal("Unexpected type error")
}
err7 := errors.New("original error")
err8 := PrependError("Prepend: ", err7)
if err8.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
}