THRIFT-3235 C#: Limit recursion depth to 64

Client: C#
Patch: Jens Geyer
This commit is contained in:
Jens Geyer 2015-07-09 23:02:46 +02:00
parent d47fcdd66d
commit 4018036980
3 changed files with 148 additions and 64 deletions

View File

@ -909,6 +909,10 @@ void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct*
indent(out) << "public void Read (TProtocol iprot)" << endl;
scope_up(out);
out << indent() << "iprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
const vector<t_field*>& fields = tstruct->get_members();
vector<t_field*>::const_iterator f_iter;
@ -977,6 +981,12 @@ void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct*
}
}
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "iprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent_down();
indent(out) << "}" << endl << endl;
@ -985,6 +995,10 @@ void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct*
void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct* tstruct) {
out << indent() << "public void Write(TProtocol oprot) {" << endl;
indent_up();
out << indent() << "oprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
string name = tstruct->get_name();
const vector<t_field*>& fields = tstruct->get_sorted_members();
@ -1030,8 +1044,14 @@ void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct*
indent(out) << "oprot.WriteFieldStop();" << endl;
indent(out) << "oprot.WriteStructEnd();" << endl;
indent_down();
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "oprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent_down();
indent(out) << "}" << endl << endl;
}
@ -1039,6 +1059,10 @@ void t_csharp_generator::generate_csharp_struct_result_writer(ofstream& out, t_s
indent(out) << "public void Write(TProtocol oprot) {" << endl;
indent_up();
out << indent() << "oprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
string name = tstruct->get_name();
const vector<t_field*>& fields = tstruct->get_sorted_members();
vector<t_field*>::const_iterator f_iter;
@ -1092,6 +1116,12 @@ void t_csharp_generator::generate_csharp_struct_result_writer(ofstream& out, t_s
out << endl << indent() << "oprot.WriteFieldStop();" << endl << indent()
<< "oprot.WriteStructEnd();" << endl;
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "oprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent_down();
indent(out) << "}" << endl << endl;
@ -1249,6 +1279,11 @@ void t_csharp_generator::generate_csharp_union_class(std::ofstream& out,
indent(out) << "}" << endl;
indent(out) << "public override void Write(TProtocol oprot) {" << endl;
indent_up();
out << indent() << "oprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
indent(out) << "TStruct struc = new TStruct(\"" << tunion->get_name() << "\");" << endl;
indent(out) << "oprot.WriteStructBegin(struc);" << endl;
@ -1264,6 +1299,13 @@ void t_csharp_generator::generate_csharp_union_class(std::ofstream& out,
indent(out) << "oprot.WriteFieldStop();" << endl;
indent(out) << "oprot.WriteStructEnd();" << endl;
indent_down();
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "oprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent(out) << "}" << endl;
indent_down();
@ -1987,6 +2029,11 @@ void t_csharp_generator::generate_csharp_union_reader(std::ofstream& out, t_stru
indent(out) << "public static " << tunion->get_name() << " Read(TProtocol iprot)" << endl;
scope_up(out);
out << indent() << "iprot.IncrementRecursionDepth();" << endl;
out << indent() << "try" << endl;
scope_up(out);
indent(out) << tunion->get_name() << " retval;" << endl;
indent(out) << "iprot.ReadStructBegin();" << endl;
indent(out) << "TField field = iprot.ReadFieldBegin();" << endl;
@ -2036,13 +2083,16 @@ void t_csharp_generator::generate_csharp_union_reader(std::ofstream& out, t_stru
// end of else for TStop
scope_down(out);
indent(out) << "iprot.ReadStructEnd();" << endl;
indent(out) << "return retval;" << endl;
indent_down();
scope_down(out);
out << indent() << "finally" << endl;
scope_up(out);
out << indent() << "iprot.DecrementRecursionDepth();" << endl;
scope_down(out);
indent(out) << "}" << endl << endl;
}

View File

@ -29,11 +29,17 @@ namespace Thrift.Protocol
{
public abstract class TProtocol : IDisposable
{
private const int DEFAULT_RECURSION_DEPTH = 64;
protected TTransport trans;
protected int recursionLimit;
protected int recursionDepth;
protected TProtocol(TTransport trans)
{
this.trans = trans;
this.recursionLimit = DEFAULT_RECURSION_DEPTH;
this.recursionDepth = 0;
}
public TTransport Transport
@ -41,6 +47,25 @@ namespace Thrift.Protocol
get { return trans; }
}
public int RecursionLimit
{
get { return recursionLimit; }
set { recursionLimit = value; }
}
public void IncrementRecursionDepth()
{
if (recursionDepth < recursionLimit)
++recursionDepth;
else
throw new TProtocolException(TProtocolException.DEPTH_LIMIT, "Depth limit exceeded");
}
public void DecrementRecursionDepth()
{
--recursionDepth;
}
#region " IDisposable Support "
private bool _IsDisposed;

View File

@ -29,69 +29,78 @@ namespace Thrift.Protocol
{
public static void Skip(TProtocol prot, TType type)
{
switch (type)
prot.IncrementRecursionDepth();
try
{
case TType.Bool:
prot.ReadBool();
break;
case TType.Byte:
prot.ReadByte();
break;
case TType.I16:
prot.ReadI16();
break;
case TType.I32:
prot.ReadI32();
break;
case TType.I64:
prot.ReadI64();
break;
case TType.Double:
prot.ReadDouble();
break;
case TType.String:
// Don't try to decode the string, just skip it.
prot.ReadBinary();
break;
case TType.Struct:
prot.ReadStructBegin();
while (true)
{
TField field = prot.ReadFieldBegin();
if (field.Type == TType.Stop)
switch (type)
{
case TType.Bool:
prot.ReadBool();
break;
case TType.Byte:
prot.ReadByte();
break;
case TType.I16:
prot.ReadI16();
break;
case TType.I32:
prot.ReadI32();
break;
case TType.I64:
prot.ReadI64();
break;
case TType.Double:
prot.ReadDouble();
break;
case TType.String:
// Don't try to decode the string, just skip it.
prot.ReadBinary();
break;
case TType.Struct:
prot.ReadStructBegin();
while (true)
{
break;
TField field = prot.ReadFieldBegin();
if (field.Type == TType.Stop)
{
break;
}
Skip(prot, field.Type);
prot.ReadFieldEnd();
}
Skip(prot, field.Type);
prot.ReadFieldEnd();
}
prot.ReadStructEnd();
break;
case TType.Map:
TMap map = prot.ReadMapBegin();
for (int i = 0; i < map.Count; i++)
{
Skip(prot, map.KeyType);
Skip(prot, map.ValueType);
}
prot.ReadMapEnd();
break;
case TType.Set:
TSet set = prot.ReadSetBegin();
for (int i = 0; i < set.Count; i++)
{
Skip(prot, set.ElementType);
}
prot.ReadSetEnd();
break;
case TType.List:
TList list = prot.ReadListBegin();
for (int i = 0; i < list.Count; i++)
{
Skip(prot, list.ElementType);
}
prot.ReadListEnd();
break;
prot.ReadStructEnd();
break;
case TType.Map:
TMap map = prot.ReadMapBegin();
for (int i = 0; i < map.Count; i++)
{
Skip(prot, map.KeyType);
Skip(prot, map.ValueType);
}
prot.ReadMapEnd();
break;
case TType.Set:
TSet set = prot.ReadSetBegin();
for (int i = 0; i < set.Count; i++)
{
Skip(prot, set.ElementType);
}
prot.ReadSetEnd();
break;
case TType.List:
TList list = prot.ReadListBegin();
for (int i = 0; i < list.Count; i++)
{
Skip(prot, list.ElementType);
}
prot.ReadListEnd();
break;
}
}
finally
{
prot.DecrementRecursionDepth();
}
}
}