THRIFT-3235 C#: Limit recursion depth to 64

Client: C#
Patch: Jens Geyer
This commit is contained in:
Jens Geyer 2015-07-09 23:02:46 +02:00
parent d47fcdd66d
commit 4018036980
3 changed files with 148 additions and 64 deletions

View File

@ -909,6 +909,10 @@ void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct*
indent(out) << "public void Read (TProtocol iprot)" << endl; indent(out) << "public void Read (TProtocol iprot)" << endl;
scope_up(out); scope_up(out);
out << indent() << "iprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
const vector<t_field*>& fields = tstruct->get_members(); const vector<t_field*>& fields = tstruct->get_members();
vector<t_field*>::const_iterator f_iter; vector<t_field*>::const_iterator f_iter;
@ -977,6 +981,12 @@ void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct*
} }
} }
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "iprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent_down(); indent_down();
indent(out) << "}" << endl << endl; indent(out) << "}" << endl << endl;
@ -985,6 +995,10 @@ void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct*
void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct* tstruct) { void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct* tstruct) {
out << indent() << "public void Write(TProtocol oprot) {" << endl; out << indent() << "public void Write(TProtocol oprot) {" << endl;
indent_up(); indent_up();
out << indent() << "oprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
string name = tstruct->get_name(); string name = tstruct->get_name();
const vector<t_field*>& fields = tstruct->get_sorted_members(); const vector<t_field*>& fields = tstruct->get_sorted_members();
@ -1030,8 +1044,14 @@ void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct*
indent(out) << "oprot.WriteFieldStop();" << endl; indent(out) << "oprot.WriteFieldStop();" << endl;
indent(out) << "oprot.WriteStructEnd();" << endl; indent(out) << "oprot.WriteStructEnd();" << endl;
indent_down(); scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "oprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent_down();
indent(out) << "}" << endl << endl; indent(out) << "}" << endl << endl;
} }
@ -1039,6 +1059,10 @@ void t_csharp_generator::generate_csharp_struct_result_writer(ofstream& out, t_s
indent(out) << "public void Write(TProtocol oprot) {" << endl; indent(out) << "public void Write(TProtocol oprot) {" << endl;
indent_up(); indent_up();
out << indent() << "oprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
string name = tstruct->get_name(); string name = tstruct->get_name();
const vector<t_field*>& fields = tstruct->get_sorted_members(); const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter; vector<t_field*>::const_iterator f_iter;
@ -1092,6 +1116,12 @@ void t_csharp_generator::generate_csharp_struct_result_writer(ofstream& out, t_s
out << endl << indent() << "oprot.WriteFieldStop();" << endl << indent() out << endl << indent() << "oprot.WriteFieldStop();" << endl << indent()
<< "oprot.WriteStructEnd();" << endl; << "oprot.WriteStructEnd();" << endl;
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "oprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent_down(); indent_down();
indent(out) << "}" << endl << endl; indent(out) << "}" << endl << endl;
@ -1249,6 +1279,11 @@ void t_csharp_generator::generate_csharp_union_class(std::ofstream& out,
indent(out) << "}" << endl; indent(out) << "}" << endl;
indent(out) << "public override void Write(TProtocol oprot) {" << endl; indent(out) << "public override void Write(TProtocol oprot) {" << endl;
indent_up(); indent_up();
out << indent() << "oprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
indent(out) << "TStruct struc = new TStruct(\"" << tunion->get_name() << "\");" << endl; indent(out) << "TStruct struc = new TStruct(\"" << tunion->get_name() << "\");" << endl;
indent(out) << "oprot.WriteStructBegin(struc);" << endl; indent(out) << "oprot.WriteStructBegin(struc);" << endl;
@ -1264,6 +1299,13 @@ void t_csharp_generator::generate_csharp_union_class(std::ofstream& out,
indent(out) << "oprot.WriteFieldStop();" << endl; indent(out) << "oprot.WriteFieldStop();" << endl;
indent(out) << "oprot.WriteStructEnd();" << endl; indent(out) << "oprot.WriteStructEnd();" << endl;
indent_down(); indent_down();
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "oprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent(out) << "}" << endl; indent(out) << "}" << endl;
indent_down(); indent_down();
@ -1987,6 +2029,11 @@ void t_csharp_generator::generate_csharp_union_reader(std::ofstream& out, t_stru
indent(out) << "public static " << tunion->get_name() << " Read(TProtocol iprot)" << endl; indent(out) << "public static " << tunion->get_name() << " Read(TProtocol iprot)" << endl;
scope_up(out); scope_up(out);
out << indent() << "iprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
indent(out) << tunion->get_name() << " retval;" << endl; indent(out) << tunion->get_name() << " retval;" << endl;
indent(out) << "iprot.ReadStructBegin();" << endl; indent(out) << "iprot.ReadStructBegin();" << endl;
indent(out) << "TField field = iprot.ReadFieldBegin();" << endl; indent(out) << "TField field = iprot.ReadFieldBegin();" << endl;
@ -2036,13 +2083,16 @@ void t_csharp_generator::generate_csharp_union_reader(std::ofstream& out, t_stru
// end of else for TStop // end of else for TStop
scope_down(out); scope_down(out);
indent(out) << "iprot.ReadStructEnd();" << endl; indent(out) << "iprot.ReadStructEnd();" << endl;
indent(out) << "return retval;" << endl; indent(out) << "return retval;" << endl;
indent_down(); indent_down();
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "iprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent(out) << "}" << endl << endl; indent(out) << "}" << endl << endl;
} }

View File

@ -29,11 +29,17 @@ namespace Thrift.Protocol
{ {
public abstract class TProtocol : IDisposable public abstract class TProtocol : IDisposable
{ {
private const int DEFAULT_RECURSION_DEPTH = 64;
protected TTransport trans; protected TTransport trans;
protected int recursionLimit;
protected int recursionDepth;
protected TProtocol(TTransport trans) protected TProtocol(TTransport trans)
{ {
this.trans = trans; this.trans = trans;
this.recursionLimit = DEFAULT_RECURSION_DEPTH;
this.recursionDepth = 0;
} }
public TTransport Transport public TTransport Transport
@ -41,6 +47,25 @@ namespace Thrift.Protocol
get { return trans; } get { return trans; }
} }
public int RecursionLimit
{
get { return recursionLimit; }
set { recursionLimit = value; }
}
public void IncrementRecursionDepth()
{
if (recursionDepth < recursionLimit)
++recursionDepth;
else
throw new TProtocolException(TProtocolException.DEPTH_LIMIT, "Depth limit exceeded");
}
public void DecrementRecursionDepth()
{
--recursionDepth;
}
#region " IDisposable Support " #region " IDisposable Support "
private bool _IsDisposed; private bool _IsDisposed;

View File

@ -29,69 +29,78 @@ namespace Thrift.Protocol
{ {
public static void Skip(TProtocol prot, TType type) public static void Skip(TProtocol prot, TType type)
{ {
switch (type) prot.IncrementRecursionDepth();
try
{ {
case TType.Bool: switch (type)
prot.ReadBool(); {
break; case TType.Bool:
case TType.Byte: prot.ReadBool();
prot.ReadByte(); break;
break; case TType.Byte:
case TType.I16: prot.ReadByte();
prot.ReadI16(); break;
break; case TType.I16:
case TType.I32: prot.ReadI16();
prot.ReadI32(); break;
break; case TType.I32:
case TType.I64: prot.ReadI32();
prot.ReadI64(); break;
break; case TType.I64:
case TType.Double: prot.ReadI64();
prot.ReadDouble(); break;
break; case TType.Double:
case TType.String: prot.ReadDouble();
// Don't try to decode the string, just skip it. break;
prot.ReadBinary(); case TType.String:
break; // Don't try to decode the string, just skip it.
case TType.Struct: prot.ReadBinary();
prot.ReadStructBegin(); break;
while (true) case TType.Struct:
{ prot.ReadStructBegin();
TField field = prot.ReadFieldBegin(); while (true)
if (field.Type == TType.Stop)
{ {
break; TField field = prot.ReadFieldBegin();
if (field.Type == TType.Stop)
{
break;
}
Skip(prot, field.Type);
prot.ReadFieldEnd();
} }
Skip(prot, field.Type); prot.ReadStructEnd();
prot.ReadFieldEnd(); break;
} case TType.Map:
prot.ReadStructEnd(); TMap map = prot.ReadMapBegin();
break; for (int i = 0; i < map.Count; i++)
case TType.Map: {
TMap map = prot.ReadMapBegin(); Skip(prot, map.KeyType);
for (int i = 0; i < map.Count; i++) Skip(prot, map.ValueType);
{ }
Skip(prot, map.KeyType); prot.ReadMapEnd();
Skip(prot, map.ValueType); break;
} case TType.Set:
prot.ReadMapEnd(); TSet set = prot.ReadSetBegin();
break; for (int i = 0; i < set.Count; i++)
case TType.Set: {
TSet set = prot.ReadSetBegin(); Skip(prot, set.ElementType);
for (int i = 0; i < set.Count; i++) }
{ prot.ReadSetEnd();
Skip(prot, set.ElementType); break;
} case TType.List:
prot.ReadSetEnd(); TList list = prot.ReadListBegin();
break; for (int i = 0; i < list.Count; i++)
case TType.List: {
TList list = prot.ReadListBegin(); Skip(prot, list.ElementType);
for (int i = 0; i < list.Count; i++) }
{ prot.ReadListEnd();
Skip(prot, list.ElementType); break;
} }
prot.ReadListEnd();
break; }
finally
{
prot.DecrementRecursionDepth();
} }
} }
} }