THRIFT-5238 GetHashCode can throw NullReferenceException

Client: netstd
Patch: Jens Geyer

This closes #2187
This commit is contained in:
Jens Geyer 2020-06-24 23:51:01 +02:00
parent 283410126c
commit 6e16c2bc54
3 changed files with 158 additions and 72 deletions

View File

@ -1173,17 +1173,8 @@ void t_netstd_generator::generate_netstd_deepcopy_method(ostream& out, t_struct*
t_type* ttype = (*m_iter)->get_type();
string copy_op = get_deep_copy_method_call(ttype, needs_typecast);
bool have_indent = false;
if (!field_is_required(*m_iter)) {
out << indent() << "if( this.__isset." << normalize_name((*m_iter)->get_name()) << ")" << endl;
indent_up();
have_indent = true;
}
else if( type_can_be_null(ttype)) {
out << indent() << "if( this." << prop_name(*m_iter) << " != null)" << endl;
indent_up();
have_indent = true;
}
bool is_required = field_is_required(*m_iter);
generate_null_check_begin( out, *m_iter);
out << indent() << tmp_instance << "." << prop_name(*m_iter) << " = ";
if( needs_typecast) {
@ -1191,8 +1182,10 @@ void t_netstd_generator::generate_netstd_deepcopy_method(ostream& out, t_struct*
}
out << "this." << prop_name(*m_iter) << copy_op << ";" << endl;
if (have_indent) {
indent_down();
generate_null_check_end( out, *m_iter);
if( !is_required) {
out << indent() << tmp_instance << ".__isset." << normalize_name((*m_iter)->get_name())
<< " = this.__isset." << normalize_name((*m_iter)->get_name()) << ";" << endl;
}
}
@ -1306,6 +1299,44 @@ void t_netstd_generator::generate_netstd_struct_reader(ostream& out, t_struct* t
out << indent() << "}" << endl << endl;
}
void t_netstd_generator::generate_null_check_begin(ostream& out, t_field* tfield) {
bool is_required = field_is_required(tfield);
bool null_allowed = type_can_be_null(tfield->get_type());
if( null_allowed || (!is_required)) {
bool first = true;
out << indent() << "if(";
if( null_allowed) {
out << "(" << prop_name(tfield) << " != null)";
first = false;
}
if( !is_required) {
if( !first) {
out << " && ";
}
out << "__isset." << normalize_name(tfield->get_name());
}
out << ")" << endl
<< indent() << "{" << endl;
indent_up();
}
}
void t_netstd_generator::generate_null_check_end(ostream& out, t_field* tfield) {
bool is_required = field_is_required(tfield);
bool null_allowed = type_can_be_null(tfield->get_type());
if( null_allowed || (!is_required)) {
indent_down();
out << indent() << "}" << endl;
}
}
void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* tstruct)
{
out << indent() << "public async Task WriteAsync(TProtocol oprot, CancellationToken cancellationToken)" << endl
@ -1329,23 +1360,7 @@ void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* t
out << indent() << "var field = new TField();" << endl;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter)
{
bool is_required = field_is_required(*f_iter);
if (!is_required)
{
bool null_allowed = type_can_be_null((*f_iter)->get_type());
if (null_allowed)
{
out << indent() << "if (" << prop_name(*f_iter) << " != null && __isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
<< indent() << "{" << endl;
indent_up();
}
else
{
out << indent() << "if (__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
<< indent() << "{" << endl;
indent_up();
}
}
generate_null_check_begin( out, *f_iter);
out << indent() << "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl
<< indent() << "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl
<< indent() << "field.ID = " << (*f_iter)->get_key() << ";" << endl
@ -1354,11 +1369,7 @@ void t_netstd_generator::generate_netstd_struct_writer(ostream& out, t_struct* t
generate_serialize_field(out, *f_iter);
out << indent() << "await oprot.WriteFieldEndAsync(cancellationToken);" << endl;
if (!is_required)
{
indent_down();
out << indent() << "}" << endl;
}
generate_null_check_end(out, *f_iter);
}
}
@ -1482,23 +1493,8 @@ void t_netstd_generator::generate_netstd_struct_tostring(ostream& out, t_struct*
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter)
{
bool is_required = field_is_required((*f_iter));
if (!is_required)
{
bool null_allowed = type_can_be_null((*f_iter)->get_type());
if (null_allowed)
{
out << indent() << "if (" << prop_name((*f_iter)) << " != null && __isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
<< indent() << "{" << endl;
indent_up();
}
else
{
out << indent() << "if (__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl
<< indent() << "{" << endl;
indent_up();
}
}
bool is_required = field_is_required(*f_iter);
generate_null_check_begin(out, *f_iter);
if (useFirstFlag && (!had_required))
{
@ -1512,13 +1508,8 @@ void t_netstd_generator::generate_netstd_struct_tostring(ostream& out, t_struct*
out << indent() << prop_name(*f_iter) << ".ToString(sb);" << endl;
if (!is_required)
{
indent_down();
out << indent() << "}" << endl;
}
else
{
generate_null_check_end(out, *f_iter);
if (is_required) {
had_required = true; // now __count must be > 0, so we don't need to check it anymore
}
}
@ -1859,26 +1850,18 @@ void t_netstd_generator::generate_netstd_struct_hashcode(ostream& out, t_struct*
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter)
{
t_type* ttype = (*f_iter)->get_type();
if (!field_is_required((*f_iter)))
{
out << indent() << "if(__isset." << normalize_name((*f_iter)->get_name()) << ")" << endl;
indent_up();
}
generate_null_check_begin(out, *f_iter);
out << indent() << "hashcode = (hashcode * 397) + ";
if (ttype->is_container())
{
if (ttype->is_container()) {
out << "TCollections.GetHashCode(" << prop_name((*f_iter)) << ")";
}
else
{
else {
out << prop_name((*f_iter)) << ".GetHashCode()";
}
out << ";" << endl;
if (!field_is_required((*f_iter)))
{
indent_down();
}
generate_null_check_end(out, *f_iter);
}
indent_down();

View File

@ -171,4 +171,6 @@ private:
void collect_extensions_types(t_type* ttype);
void generate_extensions(ostream& out, map<string, t_type*> types);
void reset_indent();
void generate_null_check_begin(ostream& out, t_field* tfield);
void generate_null_check_end(ostream& out, t_field* tfield);
};

View File

@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation(ASF) under one
// or more contributor license agreements.See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Microsoft.VisualStudio.TestPlatform.ObjectModel;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using OptReqDefTest;
using Thrift.Collections;
namespace Thrift.Tests.DataModel
{
// ReSharper disable once InconsistentNaming
[TestClass]
public class Thrift_5238
{
private void CheckInstance(RaceDetails instance)
{
// object
Assert.IsTrue(instance.__isset.def_nested);
Assert.IsTrue(instance.__isset.opt_nested);
Assert.IsNull(instance.Def_nested);
Assert.IsNull(instance.Opt_nested);
// string
Assert.IsTrue(instance.__isset.def_four);
Assert.IsTrue(instance.__isset.opt_four);
Assert.IsNull(instance.Req_four);
Assert.IsNull(instance.Def_four);
Assert.IsNull(instance.Opt_four);
// byte[]
Assert.IsTrue(instance.__isset.def_five);
Assert.IsTrue(instance.__isset.opt_five);
Assert.IsNull(instance.Req_five);
Assert.IsNull(instance.Def_five);
Assert.IsNull(instance.Opt_five);
// list<>
Assert.IsTrue(instance.__isset.def_six);
Assert.IsTrue(instance.__isset.opt_six);
Assert.IsNull(instance.Req_six);
Assert.IsNull(instance.Opt_six);
Assert.IsNull(instance.Def_six);
}
[TestMethod]
public void Thrift_5238_ProperNullChecks()
{
var instance = new OptReqDefTest.RaceDetails();
// object
instance.Def_nested = null;
instance.Opt_nested = null;
// string
instance.Req_four = null;
instance.Def_four = null;
instance.Opt_four = null;
// byte[]
instance.Req_five = null;
instance.Def_five = null;
instance.Opt_five = null;
// list<>
instance.Req_six = null;
instance.Opt_six = null;
instance.Def_six = null;
// test the setup
CheckInstance(instance);
// validate proper null checks , any of these throws if not
instance.ToString();
instance.GetHashCode();
// validate proper null checks , any of these throws if not
var copy = instance.DeepCopy();
CheckInstance(copy);
}
}
}