From a24bcd39a50e3472b5f1bf9fb5b8526fccf7911e Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Tue, 10 May 2022 19:11:07 +0800 Subject: [PATCH] THRIFT-5582: Improve TProtocol.java for Kotlin generator (#2603) * Improve TProtocol.java through the use of interfaces * Reduce Kotlin extensions usage --- .../src/thrift/generate/t_kotlin_generator.cc | 146 +------- .../apache/thrift/TMultiplexedProcessor.java | 1 + .../thrift/protocol/TBinaryProtocol.java | 1 + .../thrift/protocol/TCompactProtocol.java | 36 ++ .../apache/thrift/protocol/TJSONProtocol.java | 1 + .../org/apache/thrift/protocol/TList.java | 8 + .../java/org/apache/thrift/protocol/TMap.java | 12 + .../org/apache/thrift/protocol/TMessage.java | 12 + .../org/apache/thrift/protocol/TProtocol.java | 339 ++++++++++++++---- .../thrift/protocol/TProtocolDecorator.java | 42 +++ .../thrift/protocol/TProtocolFactory.java | 2 +- .../apache/thrift/protocol/TProtocolUtil.java | 3 + .../apache/thrift/protocol/TReadProtocol.java | 47 +++ .../java/org/apache/thrift/protocol/TSet.java | 8 + .../thrift/protocol/TSimpleJSONProtocol.java | 1 + .../org/apache/thrift/protocol/TStruct.java | 4 + .../thrift/protocol/TTupleProtocol.java | 4 + .../thrift/protocol/TWriteProtocol.java | 49 +++ 18 files changed, 517 insertions(+), 199 deletions(-) create mode 100644 lib/java/src/main/java/org/apache/thrift/protocol/TReadProtocol.java create mode 100644 lib/java/src/main/java/org/apache/thrift/protocol/TWriteProtocol.java diff --git a/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc b/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc index 28477ba9a..3a2afe6d5 100644 --- a/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_kotlin_generator.cc @@ -94,7 +94,6 @@ public: private: std::string package_name_; std::string package_dir_; - ofstream_with_content_based_conditional_update f_extensions_; ofstream_with_content_based_conditional_update f_types_; std::string kotlin_package(); @@ -122,11 +121,6 @@ private: void generate_kotlin_struct(t_struct* tstruct, bool is_exception); - void generate_kotlin_ext_definition(std::ostream& out, std::string type_name); - void generate_kotlin_ext_field_definition(std::ostream& out); - void generate_kotlin_ext_map_definition(std::ostream& out); - void generate_kotlin_ext_container_definition(std::ostream& out, std::string type_name); - void generate_service_interface(t_service* tservice); void generate_service_client(t_service* tservice); void generate_client_call(std::ostream& out, t_service* tservice, t_function* tfunc); @@ -207,17 +201,6 @@ void t_kotlin_generator::init_generator() { string f_types_name = package_dir_ + "/" + program_->get_name() + "Constants.kt"; f_types_.open(f_types_name); f_types_ << autogen_comment() << kotlin_package(); - - string f_extensions_name = package_dir_ + "/TProtocolExt.kt"; - f_extensions_.open(f_extensions_name); - f_extensions_ << autogen_comment() << kotlin_package(); - - generate_kotlin_ext_definition(f_extensions_, "Message"); - generate_kotlin_ext_definition(f_extensions_, "Struct"); - generate_kotlin_ext_field_definition(f_extensions_); - generate_kotlin_ext_map_definition(f_extensions_); - generate_kotlin_ext_container_definition(f_extensions_, "Set"); - generate_kotlin_ext_container_definition(f_extensions_, "List"); } /** @@ -226,105 +209,6 @@ void t_kotlin_generator::init_generator() { void t_kotlin_generator::close_generator() { f_types_ << endl; f_types_.close(); - f_extensions_ << endl; - f_extensions_.close(); -} - -void t_kotlin_generator::generate_kotlin_ext_definition(std::ostream& out, std::string type_name) { - out << "internal inline fun org.apache.thrift.protocol.TProtocol.write" << type_name - << "(marker: " - "org.apache.thrift.protocol.T" - << type_name << ", action: () -> Unit) {" << endl; - indent_up(); - indent(out) << "write" << type_name << "Begin(marker)" << endl; - indent(out) << "try { action() }" << endl; - indent(out) << "finally { write" << type_name << "End() }" << endl; - scope_down(out); - out << endl; - - out << "internal inline fun org.apache.thrift.protocol.TProtocol.read" << type_name - << "(action: org.apache.thrift.protocol.T" << type_name << ".() -> R): R {" << endl; - indent_up(); - indent(out) << "val marker = read" << type_name << "Begin()" << endl; - indent(out) << "try { return action(marker) }" << endl; - indent(out) << "finally { read" << type_name << "End() }" << endl; - scope_down(out); - out << endl; -} - -void t_kotlin_generator::generate_kotlin_ext_field_definition(std::ostream& out) { - out << "internal inline fun org.apache.thrift.protocol.TProtocol.writeField(marker: " - "org.apache.thrift.protocol.TField, action: () -> Unit) {" - << endl; - indent_up(); - indent(out) << "writeFieldBegin(marker)" << endl; - indent(out) << "try { action() }" << endl; - indent(out) << "finally { writeFieldEnd() }" << endl; - scope_down(out); - out << endl; - - out << "internal inline fun org.apache.thrift.protocol.TProtocol.readField(action: " - "org.apache.thrift.protocol.TField.() -> kotlin.Unit): kotlin.Boolean {" - << endl; - indent_up(); - indent(out) << "val marker = readFieldBegin()" << endl; - indent(out) << "if (marker.type == org.apache.thrift.protocol.TType.STOP) { return true }" - << endl; - indent(out) << "try {" << endl; - indent_up(); - indent(out) << "action(marker)" << endl; - indent(out) << "return false" << endl; - indent_down(); - indent(out) << "} finally { readFieldEnd() }" << endl; - scope_down(out); - out << endl; -} - -void t_kotlin_generator::generate_kotlin_ext_map_definition(std::ostream& out) { - out << "internal inline fun org.apache.thrift.protocol.TProtocol.writeMap(keyType: " - "kotlin.Byte, valueType: kotlin.Byte, map: Map, action: (Map.Entry) -> " - "Unit) {" - << endl; - indent_up(); - indent(out) << "writeMapBegin(org.apache.thrift.protocol.TMap(keyType, valueType, map.size))" - << endl; - indent(out) << "map.forEach { action(it) }" << endl; - indent(out) << "writeMapEnd()" << endl; - scope_down(out); - out << endl; - out << "internal inline fun org.apache.thrift.protocol.TProtocol.readMap(action: " - "org.apache.thrift.protocol.TMap.() -> R): R {" - << endl; - indent_up(); - indent(out) << "val marker = readMapBegin()" << endl; - indent(out) << "val r = action(marker)" << endl; - indent(out) << "readMapEnd()" << endl; - indent(out) << "return r" << endl; - scope_down(out); - out << endl; -} - -void t_kotlin_generator::generate_kotlin_ext_container_definition(std::ostream& out, - std::string type_name) { - out << "internal inline fun org.apache.thrift.protocol.TProtocol.write" << type_name - << "(elemType: kotlin.Byte, container: " << type_name << ", action: (T) -> Unit) {" - << endl; - indent_up(); - indent(out) << "write" << type_name << "Begin(org.apache.thrift.protocol.T" << type_name - << "(elemType, container.size))" << endl; - indent(out) << "container.forEach { action(it) }" << endl; - indent(out) << "write" << type_name << "End()" << endl; - scope_down(out); - out << endl; - out << "internal inline fun org.apache.thrift.protocol.TProtocol.read" << type_name - << "(action: org.apache.thrift.protocol.T" << type_name << ".() -> R): R {" << endl; - indent_up(); - indent(out) << "val marker = read" << type_name << "Begin()" << endl; - indent(out) << "val r = action(marker)" << endl; - indent(out) << "read" << type_name << "End()" << endl; - indent(out) << "return r" << endl; - scope_down(out); - out << endl; } /** @@ -972,17 +856,18 @@ void t_kotlin_generator::generate_struct_standard_scheme_read(std::ostream& out, indent_up(); { indent(out) << "val skipNext = { " - "org.apache.thrift.protocol.TProtocolUtil.skip(iproto, type) }" + "org.apache.thrift.protocol.TProtocolUtil.skip(iproto, it.type) }" << endl; - indent(out) << "when (id.toInt()) {" << endl; + indent(out) << "when (it.id.toInt()) {" << endl; indent_up(); { for (auto& field : tstruct->get_members()) { indent(out) << field->get_key() << " -> {" << endl; indent_up(); { - indent(out) << "if (type == " << type_to_enum(field->get_type()) << ") {" << endl; + indent(out) << "if (it.type == " << type_to_enum(field->get_type()) << ") {" + << endl; indent_up(); generate_deserialize_field(out, field, "struct."); indent_down(); @@ -1027,7 +912,8 @@ void t_kotlin_generator::generate_struct_standard_scheme_write(std::ostream& out for (auto& field : tstruct->get_members()) { auto is_required = field->get_req() == t_field::T_REQUIRED; indent(out) << "struct." << kotlin_safe_name(field->get_name()) - << (is_required ? "" : "?") << ".let {" << endl; + << (is_required ? "" : "?") << ".let { " + << kotlin_safe_name(field->get_name()) << " ->" << endl; indent_up(); { indent(out) << "writeField(" << constant_name(field->get_name()) << "_FIELD_DESC) {" @@ -1391,7 +1277,7 @@ void t_kotlin_generator::generate_serialize_field(ostream& out, t_field* tfield) throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + tfield->get_name(); } indent(out); - generate_serialize_value(out, type); + generate_serialize_value(out, type, kotlin_safe_name(tfield->get_name())); out << endl; } @@ -1461,9 +1347,9 @@ void t_kotlin_generator::generate_serialize_container(ostream& out, t_type* ttyp */ void t_kotlin_generator::generate_deserialize_container(ostream& out, t_type* ttype) { if (ttype->is_map()) { - out << "readMap {" << endl; + out << "readMap { tmap ->" << endl; indent_up(); - indent(out) << "kotlin.collections.List(size) {" << endl; + indent(out) << "kotlin.collections.List(tmap.size) {" << endl; indent_up(); indent(out); generate_deserialize_value(out, ((t_map*)ttype)->get_key_type()); @@ -1475,9 +1361,9 @@ void t_kotlin_generator::generate_deserialize_container(ostream& out, t_type* tt indent_down(); indent(out) << "}"; } else if (ttype->is_set()) { - out << "readSet {" << endl; + out << "readSet { tset ->" << endl; indent_up(); - indent(out) << "kotlin.collections.List(size) {" << endl; + indent(out) << "kotlin.collections.List(tset.size) {" << endl; indent_up(); indent(out); generate_deserialize_value(out, ((t_set*)ttype)->get_elem_type()); @@ -1487,9 +1373,9 @@ void t_kotlin_generator::generate_deserialize_container(ostream& out, t_type* tt indent_down(); indent(out) << "}"; } else if (ttype->is_list()) { - out << "readList {" << endl; + out << "readList { tlist ->" << endl; indent_up(); - indent(out) << "kotlin.collections.List(size) {" << endl; + indent(out) << "kotlin.collections.List(tlist.size) {" << endl; indent_up(); indent(out); generate_deserialize_value(out, ((t_list*)ttype)->get_elem_type()); @@ -1664,19 +1550,19 @@ void t_kotlin_generator::generate_client_call(std::ostream& out, indent(out) << "return protocol.readMessage {" << endl; indent_up(); { - indent(out) << "if (type == org.apache.thrift.protocol.TMessageType.EXCEPTION) {" << endl; + indent(out) << "if (it.type == org.apache.thrift.protocol.TMessageType.EXCEPTION) {" << endl; indent_up(); indent(out) << "val ex = org.apache.thrift.TApplicationException().apply { read(protocol) }" << endl; indent(out) << "throw ex" << endl; scope_down(out); - indent(out) << "if (seqid != seqId) {" << endl; + indent(out) << "if (it.seqid != seqId) {" << endl; indent_up(); indent(out) << "throw org.apache.thrift.TApplicationException(" << endl; indent_up(); indent(out) << "org.apache.thrift.TApplicationException.BAD_SEQUENCE_ID," << endl; indent(out) << "\"" << funname - << " failed: out of sequence response: expected $seqId but got ${seqid}\"" + << " failed: out of sequence response: expected $seqId but got ${it.seqid}\"" << endl; indent_down(); indent(out) << ")" << endl; diff --git a/lib/java/src/main/java/org/apache/thrift/TMultiplexedProcessor.java b/lib/java/src/main/java/org/apache/thrift/TMultiplexedProcessor.java index 28ae13880..60c79441a 100644 --- a/lib/java/src/main/java/org/apache/thrift/TMultiplexedProcessor.java +++ b/lib/java/src/main/java/org/apache/thrift/TMultiplexedProcessor.java @@ -94,6 +94,7 @@ public class TMultiplexedProcessor implements TProcessor { * called {@link #registerProcessor(String, TProcessor) registerProcessor} during * initialization, right? :) */ + @Override public void process(TProtocol iprot, TProtocol oprot) throws TException { /* Use the actual underlying protocol (e.g. TBinaryProtocol) to read the diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TBinaryProtocol.java index 3b93dbf98..34e7517fc 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TBinaryProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TBinaryProtocol.java @@ -458,6 +458,7 @@ public class TBinaryProtocol extends TProtocol { } /** Return the minimum number of bytes a type will consume on the wire */ + @Override public int getMinSerializedSize(byte type) throws TTransportException { switch (type) { case 0: diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TCompactProtocol.java index 8adb25e69..fffd687b3 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TCompactProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TCompactProtocol.java @@ -77,6 +77,7 @@ public class TCompactProtocol extends TProtocol { this.stringLengthLimit_ = stringLengthLimit; } + @Override public TProtocol getProtocol(TTransport trans) { return new TCompactProtocol(trans, stringLengthLimit_, containerLengthLimit_); } @@ -213,6 +214,7 @@ public class TCompactProtocol extends TProtocol { * Write a struct end. This doesn't actually put anything on the wire. We use this as an * opportunity to pop the last field from the current struct off of the field stack. */ + @Override public void writeStructEnd() throws TException { lastFieldId_ = lastField_.pop(); } @@ -222,6 +224,7 @@ public class TCompactProtocol extends TProtocol { * current field id and the last one is small (< 15), then the field id will be encoded in the * 4 MSB as a delta. Otherwise, the field id will follow the type header as a zigzag varint. */ + @Override public void writeFieldBegin(TField field) throws TException { if (field.type == TType.BOOL) { // we want to possibly include the value, so we'll wait. @@ -256,6 +259,7 @@ public class TCompactProtocol extends TProtocol { } /** Write the STOP symbol so we know there are no more fields in this struct. */ + @Override public void writeFieldStop() throws TException { writeByteDirect(TType.STOP); } @@ -264,6 +268,7 @@ public class TCompactProtocol extends TProtocol { * Write a map header. If the map is empty, omit the key and value type headers, as we don't need * any additional information to skip it. */ + @Override public void writeMapBegin(TMap map) throws TException { if (map.size == 0) { writeByteDirect(0); @@ -274,11 +279,13 @@ public class TCompactProtocol extends TProtocol { } /** Write a list header. */ + @Override public void writeListBegin(TList list) throws TException { writeCollectionBegin(list.elemType, list.size); } /** Write a set header. */ + @Override public void writeSetBegin(TSet set) throws TException { writeCollectionBegin(set.elemType, set.size); } @@ -288,6 +295,7 @@ public class TCompactProtocol extends TProtocol { * header info isn't written yet. If so, decide what the right type header is for the value and * then write the field header. Otherwise, write a single byte. */ + @Override public void writeBool(boolean b) throws TException { if (booleanField_ != null) { // we haven't written the field header yet @@ -300,32 +308,38 @@ public class TCompactProtocol extends TProtocol { } /** Write a byte. Nothing to see here! */ + @Override public void writeByte(byte b) throws TException { writeByteDirect(b); } /** Write an I16 as a zigzag varint. */ + @Override public void writeI16(short i16) throws TException { writeVarint32(intToZigZag(i16)); } /** Write an i32 as a zigzag varint. */ + @Override public void writeI32(int i32) throws TException { writeVarint32(intToZigZag(i32)); } /** Write an i64 as a zigzag varint. */ + @Override public void writeI64(long i64) throws TException { writeVarint64(longToZigzag(i64)); } /** Write a double to the wire as 8 bytes. */ + @Override public void writeDouble(double dub) throws TException { fixedLongToBytes(Double.doubleToLongBits(dub), temp, 0); trans_.write(temp, 0, 8); } /** Write a string to the wire with a varint size preceding. */ + @Override public void writeString(String str) throws TException { byte[] bytes = str.getBytes(StandardCharsets.UTF_8); writeVarint32(bytes.length); @@ -333,6 +347,7 @@ public class TCompactProtocol extends TProtocol { } /** Write a byte array, using a varint for the size. */ + @Override public void writeBinary(ByteBuffer bin) throws TException { ByteBuffer bb = bin.asReadOnlyBuffer(); writeVarint32(bb.remaining()); @@ -344,14 +359,19 @@ public class TCompactProtocol extends TProtocol { // output or purpose. // + @Override public void writeMessageEnd() throws TException {} + @Override public void writeMapEnd() throws TException {} + @Override public void writeListEnd() throws TException {} + @Override public void writeSetEnd() throws TException {} + @Override public void writeFieldEnd() throws TException {} // @@ -454,6 +474,7 @@ public class TCompactProtocol extends TProtocol { // /** Read a message header. */ + @Override public TMessage readMessageBegin() throws TException { byte protocolId = readByte(); if (protocolId != PROTOCOL_ID) { @@ -478,6 +499,7 @@ public class TCompactProtocol extends TProtocol { * Read a struct begin. There's nothing on the wire for this, but it is our opportunity to push a * new struct begin marker onto the field stack. */ + @Override public TStruct readStructBegin() throws TException { lastField_.push(lastFieldId_); lastFieldId_ = 0; @@ -488,12 +510,14 @@ public class TCompactProtocol extends TProtocol { * Doesn't actually consume any wire data, just removes the last field for this struct from the * field stack. */ + @Override public void readStructEnd() throws TException { // consume the last field we read off the wire. lastFieldId_ = lastField_.pop(); } /** Read a field header off the wire. */ + @Override public TField readFieldBegin() throws TException { byte type = readByte(); @@ -531,6 +555,7 @@ public class TCompactProtocol extends TProtocol { * Read a map header off the wire. If the size is zero, skip reading the key and value type. This * means that 0-length maps will yield TMaps without the "correct" types. */ + @Override public TMap readMapBegin() throws TException { int size = readVarint32(); checkContainerReadLength(size); @@ -549,6 +574,7 @@ public class TCompactProtocol extends TProtocol { * element type header. If it's a longer list, the 4 MSB of the element type header will be 0xF, * and a varint will follow with the true size. */ + @Override public TList readListBegin() throws TException { byte size_and_type = readByte(); int size = (size_and_type >> 4) & 0x0f; @@ -566,6 +592,7 @@ public class TCompactProtocol extends TProtocol { * element type header. If it's a longer set, the 4 MSB of the element type header will be 0xF, * and a varint will follow with the true size. */ + @Override public TSet readSetBegin() throws TException { return new TSet(readListBegin()); } @@ -574,6 +601,7 @@ public class TCompactProtocol extends TProtocol { * Read a boolean off the wire. If this is a boolean field, the value should already have been * read during readFieldBegin, so we'll just consume the pre-stored value. Otherwise, read a byte. */ + @Override public boolean readBool() throws TException { if (boolValue_ != null) { boolean result = boolValue_; @@ -584,6 +612,7 @@ public class TCompactProtocol extends TProtocol { } /** Read a single byte off the wire. Nothing interesting here. */ + @Override public byte readByte() throws TException { byte b; if (trans_.getBytesRemainingInBuffer() > 0) { @@ -597,27 +626,32 @@ public class TCompactProtocol extends TProtocol { } /** Read an i16 from the wire as a zigzag varint. */ + @Override public short readI16() throws TException { return (short) zigzagToInt(readVarint32()); } /** Read an i32 from the wire as a zigzag varint. */ + @Override public int readI32() throws TException { return zigzagToInt(readVarint32()); } /** Read an i64 from the wire as a zigzag varint. */ + @Override public long readI64() throws TException { return zigzagToLong(readVarint64()); } /** No magic here - just read a double off the wire. */ + @Override public double readDouble() throws TException { trans_.readAll(temp, 0, 8); return Double.longBitsToDouble(bytesToLong(temp)); } /** Reads a byte[] (via readBinary), and then UTF-8 decodes it. */ + @Override public String readString() throws TException { int length = readVarint32(); checkStringReadLength(length); @@ -639,6 +673,7 @@ public class TCompactProtocol extends TProtocol { } /** Read a ByteBuffer from the wire. */ + @Override public ByteBuffer readBinary() throws TException { int length = readVarint32(); if (length == 0) { @@ -848,6 +883,7 @@ public class TCompactProtocol extends TProtocol { } /** Return the minimum number of bytes a type will consume on the wire */ + @Override public int getMinSerializedSize(byte type) throws TTransportException { switch (type) { case 0: diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TJSONProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TJSONProtocol.java index 53d0d715c..dd24ca4fc 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TJSONProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TJSONProtocol.java @@ -952,6 +952,7 @@ public class TJSONProtocol extends TProtocol { } /** Return the minimum number of bytes a type will consume on the wire */ + @Override public int getMinSerializedSize(byte type) throws TTransportException { switch (type) { case 0: diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TList.java b/lib/java/src/main/java/org/apache/thrift/protocol/TList.java index 25368916d..9858b0c7b 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TList.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TList.java @@ -32,4 +32,12 @@ public final class TList { public final byte elemType; public final int size; + + public byte getElemType() { + return elemType; + } + + public int getSize() { + return size; + } } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TMap.java b/lib/java/src/main/java/org/apache/thrift/protocol/TMap.java index 72dace2cd..a5d4fbdf2 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TMap.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TMap.java @@ -34,4 +34,16 @@ public final class TMap { public final byte keyType; public final byte valueType; public final int size; + + public byte getKeyType() { + return keyType; + } + + public byte getValueType() { + return valueType; + } + + public int getSize() { + return size; + } } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TMessage.java b/lib/java/src/main/java/org/apache/thrift/protocol/TMessage.java index a5468d737..69c9dd1f2 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TMessage.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TMessage.java @@ -35,6 +35,18 @@ public final class TMessage { public final byte type; public final int seqid; + public String getName() { + return name; + } + + public byte getType() { + return type; + } + + public int getSeqid() { + return seqid; + } + @Override public String toString() { return ""; diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java index a55d9f7ed..22333d57c 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocol.java @@ -19,7 +19,13 @@ package org.apache.thrift.protocol; -import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.IntFunction; import org.apache.thrift.TException; import org.apache.thrift.partial.TFieldData; import org.apache.thrift.scheme.IScheme; @@ -27,7 +33,7 @@ import org.apache.thrift.scheme.StandardScheme; import org.apache.thrift.transport.TTransport; /** Protocol interface definition. */ -public abstract class TProtocol { +public abstract class TProtocol implements TWriteProtocol, TReadProtocol { /** Prevent direct instantiation */ @SuppressWarnings("unused") @@ -52,106 +58,303 @@ public abstract class TProtocol { } protected void checkReadBytesAvailable(TList list) throws TException { - trans_.checkReadBytesAvailable(list.size * getMinSerializedSize(list.elemType)); + long size = list.getSize(); + trans_.checkReadBytesAvailable(size * getMinSerializedSize(list.elemType)); } protected void checkReadBytesAvailable(TSet set) throws TException { - trans_.checkReadBytesAvailable(set.size * getMinSerializedSize(set.elemType)); + long size = set.getSize(); + trans_.checkReadBytesAvailable(size * getMinSerializedSize(set.elemType)); } /** - * Return + * Return min serialized size in bytes * * @param type Returns the minimum amount of bytes needed to store the smallest possible instance * of TType. - * @return - * @throws TException + * @return min serialized size + * @throws TException when error happens */ public abstract int getMinSerializedSize(byte type) throws TException; - /** Writing methods. */ - public abstract void writeMessageBegin(TMessage message) throws TException; + public interface WriteCallback { + void call(T e) throws TException; + } - public abstract void writeMessageEnd() throws TException; + public interface ReadCallback { + R accept(T t) throws TException; + } - public abstract void writeStructBegin(TStruct struct) throws TException; + public interface ReadCollectionCallback { + R call() throws TException; + } - public abstract void writeStructEnd() throws TException; + public interface ReadMapEntryCallback { + K getKey() throws TException; - public abstract void writeFieldBegin(TField field) throws TException; + V getValue() throws TException; + } - public abstract void writeFieldEnd() throws TException; + public final void writeSet(byte elementType, Set set, WriteCallback callback) + throws TException { + writeSetBegin(new TSet(elementType, set.size())); + for (T t : set) { + callback.call(t); + } + writeSetEnd(); + } - public abstract void writeFieldStop() throws TException; + public final void writeList(byte elementType, List list, WriteCallback callback) + throws TException { + writeListBegin(new TList(elementType, list.size())); + for (T t : list) { + callback.call(t); + } + writeListEnd(); + } - public abstract void writeMapBegin(TMap map) throws TException; + public final void writeMap( + byte keyType, byte valueType, Map map, WriteCallback> callback) + throws TException { + writeMapBegin(new TMap(keyType, valueType, map.size())); + for (Map.Entry entry : map.entrySet()) { + callback.call(entry); + } + writeMapEnd(); + } - public abstract void writeMapEnd() throws TException; + public final void writeField(TField field, WriteCallback callback) throws TException { + writeFieldBegin(field); + callback.call(null); + writeFieldEnd(); + } - public abstract void writeListBegin(TList list) throws TException; + public final void writeStruct(TStruct struct, WriteCallback callback) throws TException { + writeStructBegin(struct); + callback.call(null); + writeStructEnd(); + } - public abstract void writeListEnd() throws TException; + public final void writeMessage(TMessage message, WriteCallback callback) throws TException { + writeMessageBegin(message); + callback.call(null); + writeMessageEnd(); + } - public abstract void writeSetBegin(TSet set) throws TException; + /** + * read a message by delegating to a callback, handles {@link #readMessageBegin() begin} and + * {@link #readMessageEnd() end} automatically. + * + * @param callback callback for actual reading + * @param result message type + * @return the message read + * @throws TException when any sub-operation failed + */ + public final T readMessage(ReadCallback callback) throws TException { + TMessage tMessage = readMessageBegin(); + T t = callback.accept(tMessage); + readMessageEnd(); + return t; + } - public abstract void writeSetEnd() throws TException; + /** + * read a struct by delegating to a callback, handles {@link #readStructBegin() begin} and {@link + * #readStructEnd() end} automatically. + * + * @param callback callback for actual reading + * @param result struct type + * @return the struct read + * @throws TException when any sub-operation failed + */ + public final T readStruct(ReadCallback callback) throws TException { + TStruct tStruct = readStructBegin(); + T t = callback.accept(tStruct); + readStructEnd(); + return t; + } - public abstract void writeBool(boolean b) throws TException; + /** + * read a field by delegating to a callback, handles {@link #readFieldBegin() begin} and {@link + * #readFieldEnd() end} automatically, and returns whether the {@link TType#STOP stop signal} was + * encountered. Because the value is not returned, you (the compiler generated code in most cases) + * are expected to set the field yourself within the callback. + * + * @param callback callback for reading a field + * @param result field type + * @return true if a stop signal was encountered, false otherwise + * @throws Exception when any sub-operation failed + */ + public final boolean readField(ReadCallback callback) throws Exception { + TField tField = readFieldBegin(); + if (tField.type == org.apache.thrift.protocol.TType.STOP) { + return true; + } + callback.accept(tField); + readFieldEnd(); + return false; + } - public abstract void writeByte(byte b) throws TException; + /** + * read a {@link Map} of elements by delegating to the callback, handles {@link #readMapBegin() + * begin} and {@link #readMapEnd() end} automatically. + * + * @param callback callback for reading the map + * @param result map type + * @return the map read + * @throws TException when any sub-operation fails + */ + public final > T readMap(ReadCallback callback) throws TException { + TMap tMap = readMapBegin(); + T t = callback.accept(tMap); + readMapEnd(); + return t; + } - public abstract void writeI16(short i16) throws TException; + /** + * read a {@link Map} of elements by delegating key and value reading to the callback, handles + * {@link #readMapBegin() begin} and {@link #readMapEnd() end} automatically. + * + * @param callback callback for reading keys and values, calls to {@link + * ReadMapEntryCallback#getKey()} and {@link ReadMapEntryCallback#getValue()} will be in + * alternating orders, i.e. k1, v1, k2, v2, .., k_n, v_n + * @param key type + * @param value type + * @return the map read + * @throws TException when any sub-operation fails + */ + public final Map readMap(ReadMapEntryCallback callback) throws TException { + return readMap(callback, HashMap::new); + } - public abstract void writeI32(int i32) throws TException; + /** + * read a {@link Map} of elements by delegating key and value reading to the callback, handles + * {@link #readMapBegin() begin} and {@link #readMapEnd() end} automatically, with a specialized + * map creator given the size hint. + * + * @param callback callback for reading keys and values, calls to {@link + * ReadMapEntryCallback#getKey()} and {@link ReadMapEntryCallback#getValue()} will be in + * alternating orders, i.e. k1, v1, k2, v2, .., k_n, v_n + * @param mapCreator map creator given the size hint + * @param key type + * @param value type + * @return the map read + * @throws TException when any sub-operation fails + */ + public final Map readMap( + ReadMapEntryCallback callback, IntFunction> mapCreator) throws TException { + return readMap( + tMap -> { + Map map = mapCreator.apply(tMap.size); + for (int i = 0; i < tMap.size; i += 1) { + map.put(callback.getKey(), callback.getValue()); + } + return map; + }); + } - public abstract void writeI64(long i64) throws TException; + /** + * read a {@link List} by delegating to the callback, handles {@link #readListBegin() begin} and + * {@link #readListEnd() end} automatically. + * + * @param callback callback for reading the list + * @param result list type + * @return the list read + * @throws TException when any sub-operation fails + */ + public final > T readList(ReadCallback callback) throws TException { + TList tList = readListBegin(); + T t = callback.accept(tList); + readListEnd(); + return t; + } - public abstract void writeDouble(double dub) throws TException; + /** + * read a {@link List} by delegating element reading to the callback, handles {@link + * #readListBegin() begin} and {@link #readListEnd() end} automatically. + * + * @param callback callback for reading one element + * @param element type + * @return list of elements read + * @throws TException when any sub-operation fails + */ + public final List readList(ReadCollectionCallback callback) throws TException { + return readList(callback, ArrayList::new); + } - public abstract void writeString(String str) throws TException; + /** + * read a {@link List} by delegating element reading to the callback, handles {@link + * #readListBegin() begin} and {@link #readListEnd() end} automatically, with a specialized list + * creator given the size hint. + * + * @param callback callback for reading one element + * @param listCreator list creator given size hint + * @param element type + * @return list of elements read + * @throws TException when any sub-operation fails + */ + public final List readList( + ReadCollectionCallback callback, IntFunction> listCreator) throws TException { + return readList( + tList -> { + List list = listCreator.apply(tList.size); + for (int i = 0; i < tList.size; i += 1) { + list.add(callback.call()); + } + return list; + }); + } - public abstract void writeBinary(ByteBuffer buf) throws TException; + /** + * read a {@link Set} of elements by delegating to the callback, handles {@link #readSetBegin() + * begin} and {@link #readSetEnd() end} automatically + * + * @param callback callback for reading the set + * @param result set type + * @return the set read + * @throws TException when any sub-operation fails + */ + public final > T readSet(ReadCallback callback) throws TException { + TSet tSet = readSetBegin(); + T t = callback.accept(tSet); + readSetEnd(); + return t; + } - /** Reading methods. */ - public abstract TMessage readMessageBegin() throws TException; + /** + * read a {@link Set} of elements by delegating element reading to the callback, handles {@link + * #readSetBegin() begin} and {@link #readSetEnd() end} automatically + * + * @param callback callback for reading one element + * @param element type + * @return set of elements read + * @throws TException when any sub-operation fails + */ + public final Set readSet(ReadCollectionCallback callback) throws TException { + return readSet(callback, HashSet::new); + } - public abstract void readMessageEnd() throws TException; - - public abstract TStruct readStructBegin() throws TException; - - public abstract void readStructEnd() throws TException; - - public abstract TField readFieldBegin() throws TException; - - public abstract void readFieldEnd() throws TException; - - public abstract TMap readMapBegin() throws TException; - - public abstract void readMapEnd() throws TException; - - public abstract TList readListBegin() throws TException; - - public abstract void readListEnd() throws TException; - - public abstract TSet readSetBegin() throws TException; - - public abstract void readSetEnd() throws TException; - - public abstract boolean readBool() throws TException; - - public abstract byte readByte() throws TException; - - public abstract short readI16() throws TException; - - public abstract int readI32() throws TException; - - public abstract long readI64() throws TException; - - public abstract double readDouble() throws TException; - - public abstract String readString() throws TException; - - public abstract ByteBuffer readBinary() throws TException; + /** + * read a {@link Set} of elements by delegating element reading to the callback, handles {@link + * #readSetBegin() begin} and {@link #readSetEnd() end} automatically, with a specialized set + * creator given the size hint. + * + * @param callback callback for reading one elment + * @param setCreator set creator given size hint + * @param element type + * @return set of elements read + * @throws TException when any sub-operation fails + */ + public final Set readSet( + ReadCollectionCallback callback, IntFunction> setCreator) throws TException { + return readSet( + tSet -> { + Set set = setCreator.apply(tSet.size); + for (int i = 0; i < tSet.size; i += 1) { + set.add(callback.call()); + } + return set; + }); + } /** * Reset any internal state back to a blank slate. This method only needs to be implemented for diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolDecorator.java b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolDecorator.java index 3b3205a98..668ebceba 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolDecorator.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolDecorator.java @@ -46,166 +46,207 @@ public abstract class TProtocolDecorator extends TProtocol { concreteProtocol = protocol; } + @Override public void writeMessageBegin(TMessage tMessage) throws TException { concreteProtocol.writeMessageBegin(tMessage); } + @Override public void writeMessageEnd() throws TException { concreteProtocol.writeMessageEnd(); } + @Override public void writeStructBegin(TStruct tStruct) throws TException { concreteProtocol.writeStructBegin(tStruct); } + @Override public void writeStructEnd() throws TException { concreteProtocol.writeStructEnd(); } + @Override public void writeFieldBegin(TField tField) throws TException { concreteProtocol.writeFieldBegin(tField); } + @Override public void writeFieldEnd() throws TException { concreteProtocol.writeFieldEnd(); } + @Override public void writeFieldStop() throws TException { concreteProtocol.writeFieldStop(); } + @Override public void writeMapBegin(TMap tMap) throws TException { concreteProtocol.writeMapBegin(tMap); } + @Override public void writeMapEnd() throws TException { concreteProtocol.writeMapEnd(); } + @Override public void writeListBegin(TList tList) throws TException { concreteProtocol.writeListBegin(tList); } + @Override public void writeListEnd() throws TException { concreteProtocol.writeListEnd(); } + @Override public void writeSetBegin(TSet tSet) throws TException { concreteProtocol.writeSetBegin(tSet); } + @Override public void writeSetEnd() throws TException { concreteProtocol.writeSetEnd(); } + @Override public void writeBool(boolean b) throws TException { concreteProtocol.writeBool(b); } + @Override public void writeByte(byte b) throws TException { concreteProtocol.writeByte(b); } + @Override public void writeI16(short i) throws TException { concreteProtocol.writeI16(i); } + @Override public void writeI32(int i) throws TException { concreteProtocol.writeI32(i); } + @Override public void writeI64(long l) throws TException { concreteProtocol.writeI64(l); } + @Override public void writeDouble(double v) throws TException { concreteProtocol.writeDouble(v); } + @Override public void writeString(String s) throws TException { concreteProtocol.writeString(s); } + @Override public void writeBinary(ByteBuffer buf) throws TException { concreteProtocol.writeBinary(buf); } + @Override public TMessage readMessageBegin() throws TException { return concreteProtocol.readMessageBegin(); } + @Override public void readMessageEnd() throws TException { concreteProtocol.readMessageEnd(); } + @Override public TStruct readStructBegin() throws TException { return concreteProtocol.readStructBegin(); } + @Override public void readStructEnd() throws TException { concreteProtocol.readStructEnd(); } + @Override public TField readFieldBegin() throws TException { return concreteProtocol.readFieldBegin(); } + @Override public void readFieldEnd() throws TException { concreteProtocol.readFieldEnd(); } + @Override public TMap readMapBegin() throws TException { return concreteProtocol.readMapBegin(); } + @Override public void readMapEnd() throws TException { concreteProtocol.readMapEnd(); } + @Override public TList readListBegin() throws TException { return concreteProtocol.readListBegin(); } + @Override public void readListEnd() throws TException { concreteProtocol.readListEnd(); } + @Override public TSet readSetBegin() throws TException { return concreteProtocol.readSetBegin(); } + @Override public void readSetEnd() throws TException { concreteProtocol.readSetEnd(); } + @Override public boolean readBool() throws TException { return concreteProtocol.readBool(); } + @Override public byte readByte() throws TException { return concreteProtocol.readByte(); } + @Override public short readI16() throws TException { return concreteProtocol.readI16(); } + @Override public int readI32() throws TException { return concreteProtocol.readI32(); } + @Override public long readI64() throws TException { return concreteProtocol.readI64(); } + @Override public double readDouble() throws TException { return concreteProtocol.readDouble(); } + @Override public String readString() throws TException { return concreteProtocol.readString(); } + @Override public ByteBuffer readBinary() throws TException { return concreteProtocol.readBinary(); } @@ -216,6 +257,7 @@ public abstract class TProtocolDecorator extends TProtocol { * @return * @throws TException */ + @Override public int getMinSerializedSize(byte type) throws TException { return concreteProtocol.getMinSerializedSize(type); } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolFactory.java b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolFactory.java index b99725320..00b9fa64d 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolFactory.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolFactory.java @@ -24,5 +24,5 @@ import org.apache.thrift.transport.TTransport; /** Factory interface for constructing protocol instances. */ public interface TProtocolFactory extends Serializable { - public TProtocol getProtocol(TTransport trans); + TProtocol getProtocol(TTransport trans); } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolUtil.java b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolUtil.java index 86cfb459a..a9f566e3a 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolUtil.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TProtocolUtil.java @@ -24,6 +24,9 @@ import org.apache.thrift.TException; /** Utility class with static methods for interacting with protocol data streams. */ public class TProtocolUtil { + // no instantiation + private TProtocolUtil() {} + /** The maximum recursive depth the skip() function will traverse before throwing a TException. */ private static int maxSkipDepth = Integer.MAX_VALUE; diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TReadProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TReadProtocol.java new file mode 100644 index 000000000..d21c9a08f --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TReadProtocol.java @@ -0,0 +1,47 @@ +package org.apache.thrift.protocol; + +import java.nio.ByteBuffer; +import org.apache.thrift.TException; + +public interface TReadProtocol { + + TMessage readMessageBegin() throws TException; + + void readMessageEnd() throws TException; + + TStruct readStructBegin() throws TException; + + void readStructEnd() throws TException; + + TField readFieldBegin() throws TException; + + void readFieldEnd() throws TException; + + TMap readMapBegin() throws TException; + + void readMapEnd() throws TException; + + TList readListBegin() throws TException; + + void readListEnd() throws TException; + + TSet readSetBegin() throws TException; + + void readSetEnd() throws TException; + + boolean readBool() throws TException; + + byte readByte() throws TException; + + short readI16() throws TException; + + int readI32() throws TException; + + long readI64() throws TException; + + double readDouble() throws TException; + + String readString() throws TException; + + ByteBuffer readBinary() throws TException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TSet.java b/lib/java/src/main/java/org/apache/thrift/protocol/TSet.java index 1d1651e7f..3443cac06 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TSet.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TSet.java @@ -36,4 +36,12 @@ public final class TSet { public final byte elemType; public final int size; + + public byte getElemType() { + return elemType; + } + + public int getSize() { + return size; + } } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TSimpleJSONProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TSimpleJSONProtocol.java index 506572b59..0b948873a 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TSimpleJSONProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TSimpleJSONProtocol.java @@ -468,6 +468,7 @@ public class TSimpleJSONProtocol extends TProtocol { } /** Return the minimum number of bytes a type will consume on the wire */ + @Override public int getMinSerializedSize(byte type) throws TException { switch (type) { case 0: diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TStruct.java b/lib/java/src/main/java/org/apache/thrift/protocol/TStruct.java index 559c61413..a30fb5de2 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TStruct.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TStruct.java @@ -30,4 +30,8 @@ public final class TStruct { } public final String name; + + public String getName() { + return name; + } } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TTupleProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TTupleProtocol.java index de47d6f0a..c4141a605 100644 --- a/lib/java/src/main/java/org/apache/thrift/protocol/TTupleProtocol.java +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TTupleProtocol.java @@ -28,6 +28,7 @@ public final class TTupleProtocol extends TCompactProtocol { public static class Factory implements TProtocolFactory { public Factory() {} + @Override public TProtocol getProtocol(TTransport trans) { return new TTupleProtocol(trans); } @@ -110,9 +111,12 @@ public final class TTupleProtocol extends TCompactProtocol { return new TSet(readListBegin(type)); } + @Override public void readMapEnd() throws TException {} + @Override public void readListEnd() throws TException {} + @Override public void readSetEnd() throws TException {} } diff --git a/lib/java/src/main/java/org/apache/thrift/protocol/TWriteProtocol.java b/lib/java/src/main/java/org/apache/thrift/protocol/TWriteProtocol.java new file mode 100644 index 000000000..339a6b8e3 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/protocol/TWriteProtocol.java @@ -0,0 +1,49 @@ +package org.apache.thrift.protocol; + +import java.nio.ByteBuffer; +import org.apache.thrift.TException; + +public interface TWriteProtocol { + + void writeMessageBegin(TMessage message) throws TException; + + void writeMessageEnd() throws TException; + + void writeStructBegin(TStruct struct) throws TException; + + void writeStructEnd() throws TException; + + void writeFieldBegin(TField field) throws TException; + + void writeFieldEnd() throws TException; + + void writeFieldStop() throws TException; + + void writeMapBegin(TMap map) throws TException; + + void writeMapEnd() throws TException; + + void writeListBegin(TList list) throws TException; + + void writeListEnd() throws TException; + + void writeSetBegin(TSet set) throws TException; + + void writeSetEnd() throws TException; + + void writeBool(boolean b) throws TException; + + void writeByte(byte b) throws TException; + + void writeI16(short i16) throws TException; + + void writeI32(int i32) throws TException; + + void writeI64(long i64) throws TException; + + void writeDouble(double dub) throws TException; + + void writeString(String str) throws TException; + + void writeBinary(ByteBuffer buf) throws TException; +}