THRIFT-5237 Implement MAX_MESSAGE_SIZE and consolidate limits into a TConfiguration class

Client: cpp
Patch: Zezeng Wang

This closes #2185
This commit is contained in:
zeshuai007 2020-06-15 17:00:33 +08:00 committed by Jens Geyer
parent 23c8e52fa0
commit 86352b4821
46 changed files with 1005 additions and 152 deletions

View File

@ -141,7 +141,8 @@ include_thrift_HEADERS = \
src/thrift/TApplicationException.h \
src/thrift/TLogging.h \
src/thrift/TToString.h \
src/thrift/TBase.h
src/thrift/TBase.h \
src/thrift/TConfiguration.h
include_concurrencydir = $(include_thriftdir)/concurrency
include_concurrency_HEADERS = \
@ -156,6 +157,10 @@ include_concurrency_HEADERS = \
include_protocoldir = $(include_thriftdir)/protocol
include_protocol_HEADERS = \
src/thrift/protocol/TEnum.h \
src/thrift/protocol/TList.h \
src/thrift/protocol/TSet.h \
src/thrift/protocol/TMap.h \
src/thrift/protocol/TBinaryProtocol.h \
src/thrift/protocol/TBinaryProtocol.tcc \
src/thrift/protocol/TCompactProtocol.h \

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_TCONFIGURATION_H
#define THRIFT_TCONFIGURATION_H
namespace apache {
namespace thrift {
class TConfiguration
{
public:
TConfiguration(int maxMessageSize = DEFAULT_MAX_MESSAGE_SIZE,
int maxFrameSize = DEFAULT_MAX_FRAME_SIZE, int recursionLimit = DEFAULT_RECURSION_DEPTH)
: maxMessageSize_(maxMessageSize), maxFrameSize_(maxFrameSize), recursionLimit_(recursionLimit) {}
const static int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
const static int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries
const static int DEFAULT_RECURSION_DEPTH = 64;
inline int getMaxMessageSize() { return maxMessageSize_; }
inline void setMaxMessageSize(int maxMessageSize) { maxMessageSize_ = maxMessageSize; }
inline int getMaxFrameSize() { return maxFrameSize_; }
inline void setMaxFrameSize(int maxFrameSize) { maxFrameSize_ = maxFrameSize; }
inline int getRecursionLimit() { return recursionLimit_; }
inline void setRecursionLimit(int recursionLimit) { recursionLimit_ = recursionLimit; }
private:
int maxMessageSize_ = DEFAULT_MAX_MESSAGE_SIZE;
int maxFrameSize_ = DEFAULT_MAX_FRAME_SIZE;
int recursionLimit_ = DEFAULT_RECURSION_DEPTH;
// TODO(someone_smart): add connection and i/o timeouts
};
}
} // apache::thrift
#endif /* THRIFT_TCONFIGURATION_H */

View File

@ -166,6 +166,24 @@ public:
inline uint32_t readBinary(std::string& str);
int getMinSerializedSize(TType type);
void checkReadBytesAvailable(TSet& set)
{
trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
}
void checkReadBytesAvailable(TList& list)
{
trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
}
void checkReadBytesAvailable(TMap& map)
{
int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
trans_->checkReadBytesAvailable(map.size_ * elmSize);
}
protected:
template <typename StrType>
uint32_t readStringBody(StrType& str, int32_t sz);

View File

@ -21,6 +21,7 @@
#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ 1
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/transport/TTransportException.h>
#include <limits>
@ -285,6 +286,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readMapBegin(TType& keyType,
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
TMap map(keyType, valType, size);
checkReadBytesAvailable(map);
return result;
}
@ -307,6 +312,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readListBegin(TType& elemType
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
TList list(elemType, size);
checkReadBytesAvailable(list);
return result;
}
@ -329,6 +338,10 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readSetBegin(TType& elemType,
throw TProtocolException(TProtocolException::SIZE_LIMIT);
}
size = (uint32_t)sizei;
TSet set(elemType, size);
checkReadBytesAvailable(set);
return result;
}
@ -447,6 +460,30 @@ uint32_t TBinaryProtocolT<Transport_, ByteOrder_>::readStringBody(StrType& str,
this->trans_->readAll(reinterpret_cast<uint8_t*>(&str[0]), size);
return (uint32_t)size;
}
// Return the minimum number of bytes a type will consume on the wire
template <class Transport_, class ByteOrder_>
int TBinaryProtocolT<Transport_, ByteOrder_>::getMinSerializedSize(TType type)
{
switch (type)
{
case T_STOP: return 0;
case T_VOID: return 0;
case T_BOOL: return sizeof(int8_t);
case T_BYTE: return sizeof(int8_t);
case T_DOUBLE: return sizeof(double);
case T_I16: return sizeof(short);
case T_I32: return sizeof(int);
case T_I64: return sizeof(long);
case T_STRING: return sizeof(int); // string length
case T_STRUCT: return 0; // empty struct
case T_MAP: return sizeof(int); // element count
case T_SET: return sizeof(int); // element count
case T_LIST: return sizeof(int); // element count
default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
}
}
}
}
} // apache::thrift::protocol

View File

@ -140,6 +140,24 @@ public:
uint32_t writeBinary(const std::string& str);
int getMinSerializedSize(TType type);
void checkReadBytesAvailable(TSet& set)
{
trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
}
void checkReadBytesAvailable(TList& list)
{
trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
}
void checkReadBytesAvailable(TMap& map)
{
int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
trans_->checkReadBytesAvailable(map.size_ * elmSize);
}
/**
* These methods are called by structs, but don't actually have any wired
* output or purpose

View File

@ -538,6 +538,9 @@ uint32_t TCompactProtocolT<Transport_>::readMapBegin(TType& keyType,
valType = getTType((int8_t)((uint8_t)kvType & 0xf));
size = (uint32_t)msize;
TMap map(keyType, valType, size);
checkReadBytesAvailable(map);
return rsize;
}
@ -570,6 +573,9 @@ uint32_t TCompactProtocolT<Transport_>::readListBegin(TType& elemType,
elemType = getTType((int8_t)(size_and_type & 0x0f));
size = (uint32_t)lsize;
TList list(elemType, size);
checkReadBytesAvailable(list);
return rsize;
}
@ -706,6 +712,8 @@ uint32_t TCompactProtocolT<Transport_>::readBinary(std::string& str) {
trans_->readAll(string_buf_, size);
str.assign((char*)string_buf_, size);
trans_->checkReadBytesAvailable(rsize + (uint32_t)size);
return rsize + (uint32_t)size;
}
@ -821,6 +829,30 @@ TType TCompactProtocolT<Transport_>::getTType(int8_t type) {
}
}
// Return the minimum number of bytes a type will consume on the wire
template <class Transport_>
int TCompactProtocolT<Transport_>::getMinSerializedSize(TType type)
{
switch (type)
{
case T_STOP: return 0;
case T_VOID: return 0;
case T_BOOL: return sizeof(int8_t);
case T_DOUBLE: return 8; // uses fixedLongToBytes() which always writes 8 bytes
case T_BYTE: return sizeof(int8_t);
case T_I16: return sizeof(int8_t); // zigzag
case T_I32: return sizeof(int8_t); // zigzag
case T_I64: return sizeof(int8_t); // zigzag
case T_STRING: return sizeof(int8_t); // string length
case T_STRUCT: return 0; // empty struct
case T_MAP: return sizeof(int8_t); // element count
case T_SET: return sizeof(int8_t); // element count
case T_LIST: return sizeof(int8_t); // element count
default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
}
}
}}} // apache::thrift::protocol
#endif // _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_ENUM_H_
#define _THRIFT_ENUM_H_
namespace apache {
namespace thrift {
namespace protocol {
/**
* Enumerated definition of the types that the Thrift protocol supports.
* Take special note of the T_END type which is used specifically to mark
* the end of a sequence of fields.
*/
enum TType {
T_STOP = 0,
T_VOID = 1,
T_BOOL = 2,
T_BYTE = 3,
T_I08 = 3,
T_I16 = 6,
T_I32 = 8,
T_U64 = 9,
T_I64 = 10,
T_DOUBLE = 4,
T_STRING = 11,
T_UTF7 = 11,
T_STRUCT = 12,
T_MAP = 13,
T_SET = 14,
T_LIST = 15,
T_UTF8 = 16,
T_UTF16 = 17
};
/**
* Enumerated definition of the message types that the Thrift protocol
* supports.
*/
enum TMessageType {
T_CALL = 1,
T_REPLY = 2,
T_EXCEPTION = 3,
T_ONEWAY = 4
};
}}} // apache::thrift::protocol
#endif // #define _THRIFT_ENUM_H_

View File

@ -1013,6 +1013,10 @@ uint32_t TJSONProtocol::readMapBegin(TType& keyType, TType& valType, uint32_t& s
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
result += readJSONObjectStart();
TMap map(keyType, valType, size);
checkReadBytesAvailable(map);
return result;
}
@ -1032,6 +1036,10 @@ uint32_t TJSONProtocol::readListBegin(TType& elemType, uint32_t& size) {
if (tmpVal > (std::numeric_limits<uint32_t>::max)())
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
TList list(elemType, size);
checkReadBytesAvailable(list);
return result;
}
@ -1049,6 +1057,10 @@ uint32_t TJSONProtocol::readSetBegin(TType& elemType, uint32_t& size) {
if (tmpVal > (std::numeric_limits<uint32_t>::max)())
throw TProtocolException(TProtocolException::SIZE_LIMIT);
size = static_cast<uint32_t>(tmpVal);
TSet set(elemType, size);
checkReadBytesAvailable(set);
return result;
}
@ -1093,6 +1105,29 @@ uint32_t TJSONProtocol::readString(std::string& str) {
uint32_t TJSONProtocol::readBinary(std::string& str) {
return readJSONBase64(str);
}
// Return the minimum number of bytes a type will consume on the wire
int TJSONProtocol::getMinSerializedSize(TType type)
{
switch (type)
{
case T_STOP: return 0;
case T_VOID: return 0;
case T_BOOL: return 1; // written as int
case T_BYTE: return 1;
case T_DOUBLE: return 1;
case T_I16: return 1;
case T_I32: return 1;
case T_I64: return 1;
case T_STRING: return 2; // empty string
case T_STRUCT: return 2; // empty struct
case T_MAP: return 2; // empty map
case T_SET: return 2; // empty set
case T_LIST: return 2; // empty list
default: throw TProtocolException(TProtocolException::UNKNOWN, "unrecognized type code");
}
}
}
}
} // apache::thrift::protocol

View File

@ -245,6 +245,24 @@ public:
uint32_t readBinary(std::string& str);
int getMinSerializedSize(TType type);
void checkReadBytesAvailable(TSet& set)
{
trans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
}
void checkReadBytesAvailable(TList& list)
{
trans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
}
void checkReadBytesAvailable(TMap& map)
{
int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
trans_->checkReadBytesAvailable(map.size_ * elmSize);
}
class LookaheadReader {
public:

View File

@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_TLIST_H_
#define _THRIFT_TLIST_H_
#include <thrift/protocol/TEnum.h>
namespace apache {
namespace thrift {
namespace protocol {
// using namespace apache::thrift::protocol;
/**
* Helper class that encapsulates list metadata.
*
*/
class TList {
public:
TList() : elemType_(T_STOP),
size_(0) {
}
TList(TType t = T_STOP, int s = 0)
: elemType_(t),
size_(s) {
}
TType elemType_;
int size_;
};
}
}
} // apache::thrift::protocol
#endif // #ifndef _THRIFT_TLIST_H_

View File

@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_TMAP_H_
#define _THRIFT_TMAP_H_
#include <thrift/protocol/TEnum.h>
namespace apache {
namespace thrift {
namespace protocol {
using namespace apache::thrift::protocol;
/**
* Helper class that encapsulates map metadata.
*
*/
class TMap {
public:
TMap()
: keyType_(T_STOP),
valueType_(T_STOP),
size_(0) {
}
TMap(TType k, TType v, int s)
: keyType_(k),
valueType_(v),
size_(s) {
}
TType keyType_;
TType valueType_;
int size_;
};
}
}
} // apache::thrift::protocol
#endif // #ifndef _THRIFT_TMAP_H_

View File

@ -27,6 +27,10 @@
#include <thrift/transport/TTransport.h>
#include <thrift/protocol/TProtocolException.h>
#include <thrift/protocol/TEnum.h>
#include <thrift/protocol/TList.h>
#include <thrift/protocol/TSet.h>
#include <thrift/protocol/TMap.h>
#include <memory>
@ -170,45 +174,6 @@ namespace protocol {
using apache::thrift::transport::TTransport;
/**
* Enumerated definition of the types that the Thrift protocol supports.
* Take special note of the T_END type which is used specifically to mark
* the end of a sequence of fields.
*/
enum TType {
T_STOP = 0,
T_VOID = 1,
T_BOOL = 2,
T_BYTE = 3,
T_I08 = 3,
T_I16 = 6,
T_I32 = 8,
T_U64 = 9,
T_I64 = 10,
T_DOUBLE = 4,
T_STRING = 11,
T_UTF7 = 11,
T_STRUCT = 12,
T_MAP = 13,
T_SET = 14,
T_LIST = 15,
T_UTF8 = 16,
T_UTF16 = 17
};
/**
* Enumerated definition of the message types that the Thrift protocol
* supports.
*/
enum TMessageType {
T_CALL = 1,
T_REPLY = 2,
T_EXCEPTION = 3,
T_ONEWAY = 4
};
static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
/**
* Abstract class for a thrift protocol driver. These are all the methods that
* a protocol must implement. Essentially, there must be some way of reading
@ -578,11 +543,34 @@ public:
uint32_t getRecursionLimit() const {return recursion_limit_;}
void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}
// Returns the minimum amount of bytes needed to store the smallest possible instance of TType.
virtual int getMinSerializedSize(TType type) {
THRIFT_UNUSED_VARIABLE(type);
return 0;
}
protected:
TProtocol(std::shared_ptr<TTransport> ptrans)
: ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
: ptrans_(ptrans), input_recursion_depth_(0), output_recursion_depth_(0),
recursion_limit_(ptrans->getConfiguration()->getRecursionLimit())
{}
virtual void checkReadBytesAvailable(TSet& set)
{
ptrans_->checkReadBytesAvailable(set.size_ * getMinSerializedSize(set.elemType_));
}
virtual void checkReadBytesAvailable(TList& list)
{
ptrans_->checkReadBytesAvailable(list.size_ * getMinSerializedSize(list.elemType_));
}
virtual void checkReadBytesAvailable(TMap& map)
{
int elmSize = getMinSerializedSize(map.keyType_) + getMinSerializedSize(map.valueType_);
ptrans_->checkReadBytesAvailable(map.size_ * elmSize);
}
std::shared_ptr<TTransport> ptrans_;
private:

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef _THRIFT_TSET_H_
#define _THRIFT_TSET_H_
#include <thrift/protocol/TEnum.h>
#include <thrift/protocol/TList.h>
namespace apache {
namespace thrift {
namespace protocol {
using namespace apache::thrift::protocol;
/**
* Helper class that encapsulates set metadata.
*
*/
class TSet {
public:
TSet() : elemType_(T_STOP), size_(0) {
}
TSet(TType t, int s)
: elemType_(t),
size_(s) {
}
TSet(TList list)
: elemType_(list.elemType_),
size_(list.size_) {
}
TType elemType_;
int size_;
};
}
}
} // apache::thrift::protocol
#endif // #ifndef _THRIFT_TSET_H_

View File

@ -118,6 +118,7 @@ const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) {
}
void TBufferedTransport::flush() {
resetConsumedMessageSize();
// Write out any data waiting in the write buffer.
auto have_bytes = static_cast<uint32_t>(wBase_ - wBuf_.get());
if (have_bytes > 0) {
@ -248,6 +249,7 @@ void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) {
}
void TFramedTransport::flush() {
resetConsumedMessageSize();
int32_t sz_hbo, sz_nbo;
assert(wBufSize_ > sizeof(sz_nbo));

View File

@ -62,6 +62,7 @@ public:
* This method is meant to eventually be nonvirtual and inlinable.
*/
uint32_t read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
uint8_t* new_rBase = rBase_ + len;
if (TDB_LIKELY(new_rBase <= rBound_)) {
std::memcpy(buf, rBase_, len);
@ -120,6 +121,7 @@ public:
* Consume doesn't require a slow path.
*/
void consume(uint32_t len) {
countConsumedMessageBytes(len);
if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) {
rBase_ += len;
} else {
@ -148,7 +150,8 @@ protected:
* performance-sensitive operation, so it is okay to just leave it to
* the concrete class to set up pointers correctly.
*/
TBufferBase() : rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {}
TBufferBase(std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config), rBase_(nullptr), rBound_(nullptr), wBase_(nullptr), wBound_(nullptr) {}
/// Convenience mutator for setting the read buffer.
void setReadBuffer(uint8_t* buf, uint32_t len) {
@ -186,8 +189,9 @@ public:
static const int DEFAULT_BUFFER_SIZE = 512;
/// Use default buffer sizes.
TBufferedTransport(std::shared_ptr<TTransport> transport)
: transport_(transport),
TBufferedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(transport),
rBufSize_(DEFAULT_BUFFER_SIZE),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(new uint8_t[rBufSize_]),
@ -196,8 +200,9 @@ public:
}
/// Use specified buffer sizes.
TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz)
: transport_(transport),
TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(transport),
rBufSize_(sz),
wBufSize_(sz),
rBuf_(new uint8_t[rBufSize_]),
@ -206,8 +211,10 @@ public:
}
/// Use specified read and write buffer sizes.
TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz)
: transport_(transport),
TBufferedTransport(std::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz,
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(transport),
rBufSize_(rsz),
wBufSize_(wsz),
rBuf_(new uint8_t[rBufSize_]),
@ -309,8 +316,9 @@ public:
static const int DEFAULT_MAX_FRAME_SIZE = 256 * 1024 * 1024;
/// Use default buffer sizes.
TFramedTransport()
: transport_(),
TFramedTransport(std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(),
rBufSize_(0),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(),
@ -319,27 +327,30 @@ public:
initPointers();
}
TFramedTransport(std::shared_ptr<TTransport> transport)
: transport_(transport),
TFramedTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(transport),
rBufSize_(0),
wBufSize_(DEFAULT_BUFFER_SIZE),
rBuf_(),
wBuf_(new uint8_t[wBufSize_]),
bufReclaimThresh_((std::numeric_limits<uint32_t>::max)()),
maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) {
maxFrameSize_(configuration_->getMaxFrameSize()) {
initPointers();
}
TFramedTransport(std::shared_ptr<TTransport> transport,
uint32_t sz,
uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)())
: transport_(transport),
uint32_t bufReclaimThresh = (std::numeric_limits<uint32_t>::max)(),
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(transport),
rBufSize_(0),
wBufSize_(sz),
rBuf_(),
wBuf_(new uint8_t[wBufSize_]),
bufReclaimThresh_(bufReclaimThresh),
maxFrameSize_(DEFAULT_MAX_FRAME_SIZE) {
maxFrameSize_(configuration_->getMaxFrameSize()) {
initPointers();
}
@ -503,7 +514,10 @@ public:
* Construct a TMemoryBuffer with a default-sized buffer,
* owned by the TMemoryBuffer object.
*/
TMemoryBuffer() { initCommon(nullptr, defaultSize, true, 0); }
TMemoryBuffer(std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config) {
initCommon(nullptr, defaultSize, true, 0);
}
/**
* Construct a TMemoryBuffer with a buffer of a specified size,
@ -511,7 +525,10 @@ public:
*
* @param sz The initial size of the buffer.
*/
TMemoryBuffer(uint32_t sz) { initCommon(nullptr, sz, true, 0); }
TMemoryBuffer(uint32_t sz, std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config) {
initCommon(nullptr, sz, true, 0);
}
/**
* Construct a TMemoryBuffer with buf as its initial contents.
@ -523,7 +540,8 @@ public:
* @param sz The size of @c buf.
* @param policy See @link MemoryPolicy @endlink .
*/
TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) {
TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE, std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config) {
if (buf == nullptr && sz != 0) {
throw TTransportException(TTransportException::BAD_ARGS,
"TMemoryBuffer given null buffer with non-zero size.");

View File

@ -52,6 +52,7 @@ void TFDTransport::close() {
}
uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
unsigned int maxRetries = 5; // same as the TSocket default
unsigned int retries = 0;
while (true) {

View File

@ -40,8 +40,10 @@ class TFDTransport : public TVirtualTransport<TFDTransport> {
public:
enum ClosePolicy { NO_CLOSE_ON_DESTROY = 0, CLOSE_ON_DESTROY = 1 };
TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY)
: fd_(fd), close_policy_(close_policy) {}
TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY,
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config), fd_(fd), close_policy_(close_policy) {
}
~TFDTransport() override {
if (close_policy_ == CLOSE_ON_DESTROY) {

View File

@ -63,8 +63,9 @@ using std::string;
using namespace apache::thrift::protocol;
using namespace apache::thrift::concurrency;
TFileTransport::TFileTransport(string path, bool readOnly)
: readState_(),
TFileTransport::TFileTransport(string path, bool readOnly, std::shared_ptr<TConfiguration> config)
: TTransport(config),
readState_(),
readBuff_(nullptr),
currentEvent_(nullptr),
readBuffSize_(DEFAULT_READ_BUFF_SIZE),
@ -519,6 +520,7 @@ void TFileTransport::writerThread() {
}
void TFileTransport::flush() {
resetConsumedMessageSize();
// file must be open for writing for any flushing to take place
if (!writerThread_.get()) {
return;
@ -537,6 +539,7 @@ void TFileTransport::flush() {
}
uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
uint32_t have = 0;
uint32_t get = 0;
@ -568,6 +571,7 @@ bool TFileTransport::peek() {
}
uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
// check if there an event is ready to be read
if (!currentEvent_) {
currentEvent_ = readEvent();

View File

@ -173,7 +173,7 @@ public:
*/
class TFileTransport : public TFileReaderTransport, public TFileWriterTransport {
public:
TFileTransport(std::string path, bool readOnly = false);
TFileTransport(std::string path, bool readOnly = false, std::shared_ptr<TConfiguration> config = nullptr);
~TFileTransport() override;
// TODO: what is the correct behaviour for this?

View File

@ -415,6 +415,7 @@ void THeaderTransport::clearHeaders() {
}
void THeaderTransport::flush() {
resetConsumedMessageSize();
// Write out any data waiting in the write buffer.
uint32_t haveBytes = getWriteBytes();

View File

@ -74,8 +74,9 @@ public:
static const int THRIFT_MAX_VARINT32_BYTES = 5;
/// Use default buffer sizes.
explicit THeaderTransport(const std::shared_ptr<TTransport>& transport)
: TVirtualTransport(transport),
explicit THeaderTransport(const std::shared_ptr<TTransport>& transport,
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(transport, config),
outTransport_(transport),
protoId(T_COMPACT_PROTOCOL),
clientType(THRIFT_HEADER_CLIENT_TYPE),
@ -88,8 +89,9 @@ public:
}
THeaderTransport(const std::shared_ptr<TTransport> inTransport,
const std::shared_ptr<TTransport> outTransport)
: TVirtualTransport(inTransport),
const std::shared_ptr<TTransport> outTransport,
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(inTransport, config),
outTransport_(outTransport),
protoId(T_COMPACT_PROTOCOL),
clientType(THRIFT_HEADER_CLIENT_TYPE),

View File

@ -34,12 +34,16 @@ namespace transport {
THttpClient::THttpClient(std::shared_ptr<TTransport> transport,
std::string host,
std::string path)
: THttpTransport(transport), host_(host), path_(path) {
std::string path,
std::shared_ptr<TConfiguration> config)
: THttpTransport(transport, config),
host_(host),
path_(path) {
}
THttpClient::THttpClient(string host, int port, string path)
: THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port))),
THttpClient::THttpClient(string host, int port, string path,
std::shared_ptr<TConfiguration> config)
: THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port)), config),
host_(host),
path_(path) {
}
@ -93,6 +97,7 @@ bool THttpClient::parseStatusLine(char* status) {
}
void THttpClient::flush() {
resetConsumedMessageSize();
// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;

View File

@ -40,13 +40,16 @@ public:
*/
THttpClient(std::shared_ptr<TTransport> transport,
std::string host = "localhost",
std::string path = "/service");
std::string path = "/service",
std::shared_ptr<TConfiguration> config = nullptr);
/**
* @brief Constructor that will create a new socket transport using the host
* and port.
*/
THttpClient(std::string host, int port, std::string path = "");
THttpClient(std::string host, int port,
std::string path = "",
std::shared_ptr<TConfiguration> config = nullptr);
~THttpClient() override;

View File

@ -34,7 +34,9 @@ namespace apache {
namespace thrift {
namespace transport {
THttpServer::THttpServer(std::shared_ptr<TTransport> transport) : THttpTransport(transport) {
THttpServer::THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config)
: THttpTransport(transport, config) {
}
THttpServer::~THttpServer() = default;
@ -118,6 +120,7 @@ bool THttpServer::parseStatusLine(char* status) {
}
void THttpServer::flush() {
resetConsumedMessageSize();
// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;

View File

@ -28,7 +28,7 @@ namespace transport {
class THttpServer : public THttpTransport {
public:
THttpServer(std::shared_ptr<TTransport> transport);
THttpServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr);
~THttpServer() override;

View File

@ -31,8 +31,9 @@ namespace transport {
const char* THttpTransport::CRLF = "\r\n";
const int THttpTransport::CRLF_LEN = 2;
THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport)
: transport_(transport),
THttpTransport::THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config)
: TVirtualTransport(config),
transport_(transport),
origin_(""),
readHeaders_(true),
chunked_(false),
@ -61,6 +62,7 @@ THttpTransport::~THttpTransport() {
}
uint32_t THttpTransport::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
if (readBuffer_.available_read() == 0) {
readBuffer_.resetBuffer();
uint32_t got = readMoreData();

View File

@ -36,7 +36,7 @@ namespace transport {
*/
class THttpTransport : public TVirtualTransport<THttpTransport> {
public:
THttpTransport(std::shared_ptr<TTransport> transport);
THttpTransport(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr);
~THttpTransport() override;
@ -54,7 +54,9 @@ public:
void write(const uint8_t* buf, uint32_t len);
void flush() override = 0;
void flush() override {
resetConsumedMessageSize();
};
const std::string getOrigin() const override;

View File

@ -222,30 +222,35 @@ uint32_t pseudo_sync_read(HANDLE pipe, HANDLE event, uint8_t* buf, uint32_t len)
}
//---- Constructors ----
TPipe::TPipe(TAutoHandle &Pipe)
: impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3), isAnonymous_(false) {
TPipe::TPipe(TAutoHandle &Pipe, std::shared_ptr<TConfiguration> config)
: impl_(new TWaitableNamedPipeImpl(Pipe)), TimeoutSeconds_(3),
isAnonymous_(false), TVirtualTransport(config) {
}
TPipe::TPipe(HANDLE Pipe)
: TimeoutSeconds_(3), isAnonymous_(false)
TPipe::TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config)
: TimeoutSeconds_(3), isAnonymous_(false), TVirtualTransport(config)
{
TAutoHandle pipeHandle(Pipe);
impl_.reset(new TWaitableNamedPipeImpl(pipeHandle));
}
TPipe::TPipe(const char* pipename) : TimeoutSeconds_(3), isAnonymous_(false) {
TPipe::TPipe(const char* pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3),
isAnonymous_(false), TVirtualTransport(config) {
setPipename(pipename);
}
TPipe::TPipe(const std::string& pipename) : TimeoutSeconds_(3), isAnonymous_(false) {
TPipe::TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3),
isAnonymous_(false), TVirtualTransport(config) {
setPipename(pipename);
}
TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt)
: impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true) {
TPipe::TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config)
: impl_(new TAnonPipeImpl(PipeRd, PipeWrt)), TimeoutSeconds_(3), isAnonymous_(true),
TVirtualTransport(config) {
}
TPipe::TPipe() : TimeoutSeconds_(3), isAnonymous_(false) {
TPipe::TPipe(std::shared_ptr<TConfiguration> config) : TimeoutSeconds_(3), isAnonymous_(false),
TVirtualTransport(config) {
}
TPipe::~TPipe() {
@ -299,6 +304,7 @@ void TPipe::close() {
}
uint32_t TPipe::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
if (!isOpen())
throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open pipe");
return impl_->read(buf, len);

View File

@ -49,15 +49,15 @@ class TPipeImpl;
class TPipe : public TVirtualTransport<TPipe> {
public:
// Constructs a new pipe object.
TPipe();
TPipe(std::shared_ptr<TConfiguration> config = nullptr);
// Named pipe constructors -
explicit TPipe(HANDLE Pipe); // HANDLE is a void*
explicit TPipe(TAutoHandle& Pipe); // this ctor will clear out / move from Pipe
explicit TPipe(HANDLE Pipe, std::shared_ptr<TConfiguration> config = nullptr); // HANDLE is a void*
explicit TPipe(TAutoHandle& Pipe, std::shared_ptr<TConfiguration> config = nullptr); // this ctor will clear out / move from Pipe
// need a const char * overload so string literals don't go to the HANDLE overload
explicit TPipe(const char* pipename);
explicit TPipe(const std::string& pipename);
explicit TPipe(const char* pipename, std::shared_ptr<TConfiguration> config = nullptr);
explicit TPipe(const std::string& pipename, std::shared_ptr<TConfiguration> config = nullptr);
// Anonymous pipe -
TPipe(HANDLE PipeRd, HANDLE PipeWrt);
TPipe(HANDLE PipeRd, HANDLE PipeWrt, std::shared_ptr<TConfiguration> config = nullptr);
// Destroys the pipe object, closing it if necessary.
virtual ~TPipe();

View File

@ -214,34 +214,37 @@ SSL* SSLContext::createSSL() {
}
// TSSLSocket implementation
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx)
: TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config)
: TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener)
: TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TSocket(config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
interruptListener_ = interruptListener;
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
: TSocket(socket), server_(false), ssl_(nullptr), ctx_(ctx) {
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
: TSocket(socket, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
: TSocket(socket, interruptListener), server_(false), ssl_(nullptr), ctx_(ctx) {
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TSocket(socket, interruptListener, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port)
: TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<TConfiguration> config)
: TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
}
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener)
: TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TSocket(host, port, config), server_(false), ssl_(nullptr), ctx_(ctx) {
init();
interruptListener_ = interruptListener;
}
@ -391,6 +394,7 @@ void TSSLSocket::close() {
* exception incase of failure.
*/
uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
initializeHandshake();
if (!checkHandshake())
throw TTransportException(TTransportException::UNKNOWN, "retry again");
@ -553,6 +557,7 @@ uint32_t TSSLSocket::write_partial(const uint8_t* buf, uint32_t len) {
}
void TSSLSocket::flush() {
resetConsumedMessageSize();
// Don't throw exception if not open. Thrift servers close socket twice.
if (ssl_ == nullptr) {
return;

View File

@ -111,37 +111,40 @@ protected:
/**
* Constructor.
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx);
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener);
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket.
*
* @param socket An existing socket
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket);
TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
*
* @param socket An existing socket
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor.
*
* @param host Remote host name
* @param port Remote port number
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port);
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor with an interrupt signal.
*
* @param host Remote host name
* @param port Remote port number
*/
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener);
TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Authorize peer access after SSL handshake completes.
*/

View File

@ -38,8 +38,10 @@ namespace test {
*/
class TShortReadTransport : public TVirtualTransport<TShortReadTransport> {
public:
TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob)
: transport_(transport), fullProb_(full_prob) {}
TShortReadTransport(std::shared_ptr<TTransport> transport, double full_prob,
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config), transport_(transport), fullProb_(full_prob) {
}
bool isOpen() const override { return transport_->isOpen(); }
@ -50,6 +52,7 @@ public:
void close() override { transport_->close(); }
uint32_t read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
if (len == 0) {
return 0;
}
@ -62,11 +65,17 @@ public:
void write(const uint8_t* buf, uint32_t len) { transport_->write(buf, len); }
void flush() override { transport_->flush(); }
void flush() override {
resetConsumedMessageSize();
transport_->flush();
}
const uint8_t* borrow(uint8_t* buf, uint32_t* len) { return transport_->borrow(buf, len); }
void consume(uint32_t len) { return transport_->consume(len); }
void consume(uint32_t len) {
countConsumedMessageBytes(len);
return transport_->consume(len);
}
std::shared_ptr<TTransport> getUnderlyingTransport() { return transport_; }

View File

@ -35,8 +35,8 @@ namespace apache {
namespace thrift {
namespace transport {
TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write)
: TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) {
TSimpleFileTransport::TSimpleFileTransport(const std::string& path, bool read, bool write, std::shared_ptr<TConfiguration> config)
: TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY, config) {
int flags = 0;
if (read && write) {
flags = O_RDWR;

View File

@ -33,7 +33,8 @@ namespace transport {
*/
class TSimpleFileTransport : public TFDTransport {
public:
TSimpleFileTransport(const std::string& path, bool read = true, bool write = false);
TSimpleFileTransport(const std::string& path, bool read = true, bool write = false,
std::shared_ptr<TConfiguration> config = nullptr);
};
}
}

View File

@ -77,8 +77,9 @@ namespace transport {
*
*/
TSocket::TSocket(const string& host, int port)
: host_(host),
TSocket::TSocket(const string& host, int port, std::shared_ptr<TConfiguration> config)
: TVirtualTransport(config),
host_(host),
port_(port),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
@ -92,8 +93,9 @@ TSocket::TSocket(const string& host, int port)
maxRecvRetries_(5) {
}
TSocket::TSocket(const string& path)
: port_(0),
TSocket::TSocket(const string& path, std::shared_ptr<TConfiguration> config)
: TVirtualTransport(config),
port_(0),
path_(path),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
@ -108,8 +110,9 @@ TSocket::TSocket(const string& path)
cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
}
TSocket::TSocket()
: port_(0),
TSocket::TSocket(std::shared_ptr<TConfiguration> config)
: TVirtualTransport(config),
port_(0),
socket_(THRIFT_INVALID_SOCKET),
peerPort_(0),
connTimeout_(0),
@ -123,8 +126,9 @@ TSocket::TSocket()
cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
}
TSocket::TSocket(THRIFT_SOCKET socket)
: port_(0),
TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
: TVirtualTransport(config),
port_(0),
socket_(socket),
peerPort_(0),
connTimeout_(0),
@ -144,8 +148,10 @@ TSocket::TSocket(THRIFT_SOCKET socket)
#endif
}
TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
: port_(0),
TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config)
: TVirtualTransport(config),
port_(0),
socket_(socket),
peerPort_(0),
interruptListener_(interruptListener),
@ -522,6 +528,7 @@ void TSocket::setSocketFD(THRIFT_SOCKET socket) {
}
uint32_t TSocket::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
if (socket_ == THRIFT_INVALID_SOCKET) {
throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
}

View File

@ -52,7 +52,7 @@ public:
* socket.
*
*/
TSocket();
TSocket(std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructs a new socket. Note that this does NOT actually connect the
@ -61,7 +61,7 @@ public:
* @param host An IP address or hostname to connect to
* @param port The port to connect on
*/
TSocket(const std::string& host, int port);
TSocket(const std::string& host, int port, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructs a new Unix domain socket.
@ -69,7 +69,7 @@ public:
*
* @param path The Unix domain socket e.g. "/tmp/ThriftTest.binary.thrift"
*/
TSocket(const std::string& path);
TSocket(const std::string& path, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Destroyes the socket object, closing it if necessary.
@ -264,13 +264,14 @@ public:
/**
* Constructor to create socket from file descriptor.
*/
TSocket(THRIFT_SOCKET socket);
TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
/**
* Constructor to create socket from file descriptor that
* can be interrupted safely.
*/
TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
std::shared_ptr<TConfiguration> config = nullptr);
/**
* Set a cache of the peer address (used when trivially available: e.g.

View File

@ -21,6 +21,7 @@
#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1
#include <thrift/Thrift.h>
#include <thrift/TConfiguration.h>
#include <thrift/transport/TTransportException.h>
#include <memory>
#include <string>
@ -55,6 +56,15 @@ uint32_t readAll(Transport_& trans, uint8_t* buf, uint32_t len) {
*/
class TTransport {
public:
TTransport(std::shared_ptr<TConfiguration> config = nullptr) {
if(config == nullptr) {
configuration_ = std::shared_ptr<TConfiguration> (new TConfiguration());
} else {
configuration_ = config;
}
resetConsumedMessageSize();
}
/**
* Virtual deconstructor.
*/
@ -238,11 +248,87 @@ public:
*/
virtual const std::string getOrigin() const { return "Unknown"; }
protected:
std::shared_ptr<TConfiguration> getConfiguration() { return configuration_; }
void setConfiguration(std::shared_ptr<TConfiguration> config) {
if (config != nullptr) configuration_ = config;
}
/**
* Simple constructor.
* Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport).
* Will throw if we already consumed too many bytes or if the new size is larger than allowed.
*
* @param size real message size
*/
TTransport() = default;
void updateKnownMessageSize(long int size)
{
long int consumed = knownMessageSize_ - remainingMessageSize_;
resetConsumedMessageSize(size);
countConsumedMessageBytes(consumed);
}
/**
* Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data
*
* @param numBytes numBytes bytes of data
*/
void checkReadBytesAvailable(long int numBytes)
{
if (remainingMessageSize_ < numBytes)
throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
}
protected:
std::shared_ptr<TConfiguration> configuration_;
long int remainingMessageSize_;
long int knownMessageSize_;
inline long int getRemainingMessageSize() { return remainingMessageSize_; }
inline void setRemainingMessageSize(long int remainingMessageSize) { remainingMessageSize_ = remainingMessageSize; }
inline int getMaxMessageSize() { return configuration_->getMaxMessageSize(); }
inline long int getKnownMessageSize() { return knownMessageSize_; }
void setKnownMessageSize(long int knownMessageSize) { knownMessageSize_ = knownMessageSize; }
/**
* Resets RemainingMessageSize to the configured maximum
*
* @param newSize configured size
*/
void resetConsumedMessageSize(long newSize = -1)
{
// full reset
if (newSize < 0)
{
knownMessageSize_ = getMaxMessageSize();
remainingMessageSize_ = getMaxMessageSize();
return;
}
// update only: message size can shrink, but not grow
if (newSize > knownMessageSize_)
throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
knownMessageSize_ = newSize;
remainingMessageSize_ = newSize;
}
/**
* Consumes numBytes from the RemainingMessageSize.
*
* @param numBytes Consumes numBytes
*/
void countConsumedMessageBytes(long int numBytes)
{
if (remainingMessageSize_ >= numBytes)
{
remainingMessageSize_ -= numBytes;
}
else
{
remainingMessageSize_ = 0;
throw new TTransportException(TTransportException::END_OF_FILE, "MaxMessageSize reached");
}
}
};
/**

View File

@ -26,6 +26,7 @@ namespace thrift {
namespace transport {
uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
uint32_t need = len;
// We don't have enough data yet
@ -104,8 +105,9 @@ void TPipedTransport::flush() {
TPipedFileReaderTransport::TPipedFileReaderTransport(
std::shared_ptr<TFileReaderTransport> srcTrans,
std::shared_ptr<TTransport> dstTrans)
: TPipedTransport(srcTrans, dstTrans), srcTrans_(srcTrans) {
std::shared_ptr<TTransport> dstTrans,
std::shared_ptr<TConfiguration> config)
: TPipedTransport(srcTrans, dstTrans, config), srcTrans_(srcTrans) {
}
TPipedFileReaderTransport::~TPipedFileReaderTransport() = default;
@ -131,6 +133,7 @@ uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) {
}
uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
uint32_t have = 0;
uint32_t get = 0;

View File

@ -63,8 +63,10 @@ public:
*/
class TPipedTransport : virtual public TTransport {
public:
TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans)
: srcTrans_(srcTrans),
TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans,
std::shared_ptr<TConfiguration> config = nullptr)
: TTransport(config),
srcTrans_(srcTrans),
dstTrans_(dstTrans),
rBufSize_(512),
rPos_(0),
@ -88,8 +90,10 @@ public:
TPipedTransport(std::shared_ptr<TTransport> srcTrans,
std::shared_ptr<TTransport> dstTrans,
uint32_t sz)
: srcTrans_(srcTrans),
uint32_t sz,
std::shared_ptr<TConfiguration> config = nullptr)
: TTransport(config),
srcTrans_(srcTrans),
dstTrans_(dstTrans),
rBufSize_(512),
rPos_(0),
@ -241,7 +245,8 @@ protected:
class TPipedFileReaderTransport : public TPipedTransport, public TFileReaderTransport {
public:
TPipedFileReaderTransport(std::shared_ptr<TFileReaderTransport> srcTrans,
std::shared_ptr<TTransport> dstTrans);
std::shared_ptr<TTransport> dstTrans,
std::shared_ptr<TConfiguration> config = nullptr);
~TPipedFileReaderTransport() override;

View File

@ -57,7 +57,7 @@ public:
void consume(uint32_t len) { this->TTransport::consume_virt(len); }
protected:
TTransportDefaults() = default;
TTransportDefaults(std::shared_ptr<TConfiguration> config = nullptr) : TTransport(config) {}
};
/**
@ -118,7 +118,7 @@ public:
}
protected:
TVirtualTransport() = default;
TVirtualTransport() : Super_() {}
/*
* Templatized constructors, to allow arguments to be passed to the Super_

View File

@ -53,8 +53,8 @@ std::string base64Encode(unsigned char* data, int length);
template <bool binary>
class TWebSocketServer : public THttpServer {
public:
TWebSocketServer(std::shared_ptr<TTransport> transport)
: THttpServer(transport) {
TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
: THttpServer(transport, config) {
resetHandshake();
}
@ -98,6 +98,7 @@ public:
}
void flush() override {
resetConsumedMessageSize();
writeFrameHeader();
uint8_t* buffer;
uint32_t length;

View File

@ -136,6 +136,7 @@ inline int TZlibTransport::readAvail() const {
}
uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) {
checkReadBytesAvailable(len);
uint32_t need = len;
// TODO(dreiss): Skip urbuf on big reads.
@ -265,6 +266,7 @@ void TZlibTransport::flush() {
}
flushToTransport(Z_FULL_FLUSH);
resetConsumedMessageSize();
}
void TZlibTransport::finish() {
@ -335,6 +337,7 @@ const uint8_t* TZlibTransport::borrow(uint8_t* buf, uint32_t* len) {
}
void TZlibTransport::consume(uint32_t len) {
countConsumedMessageBytes(len);
if (readAvail() >= (int)len) {
urpos_ += len;
} else {

View File

@ -83,8 +83,10 @@ public:
int crbuf_size = DEFAULT_CRBUF_SIZE,
int uwbuf_size = DEFAULT_UWBUF_SIZE,
int cwbuf_size = DEFAULT_CWBUF_SIZE,
int16_t comp_level = Z_DEFAULT_COMPRESSION)
: transport_(transport),
int16_t comp_level = Z_DEFAULT_COMPRESSION,
std::shared_ptr<TConfiguration> config = nullptr)
: TVirtualTransport(config),
transport_(transport),
urpos_(0),
uwpos_(0),
input_ended_(false),

View File

@ -81,6 +81,7 @@ set(UnitTest_SOURCES
TypedefTest.cpp
TServerSocketTest.cpp
TServerTransportTest.cpp
ThrifttReadCheckTests.cpp
)
add_executable(UnitTests ${UnitTest_SOURCES})

View File

@ -130,7 +130,8 @@ UnitTests_SOURCES = \
TypedefTest.cpp \
TServerSocketTest.cpp \
TServerTransportTest.cpp \
TTransportCheckThrow.h
TTransportCheckThrow.h \
ThrifttReadCheckTests.cpp
UnitTests_LDADD = \
libtestgencpp.la \

View File

@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#define MAX_MESSAGE_SIZE 2
#include <boost/test/auto_unit_test.hpp>
#include <boost/test/unit_test.hpp>
#include <iostream>
#include <climits>
#include <vector>
#include <thrift/TConfiguration.h>
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/protocol/TCompactProtocol.h>
#include <thrift/protocol/TJSONProtocol.h>
#include <thrift/Thrift.h>
#include <memory>
#include <thrift/transport/TTransportUtils.h>
#include <thrift/transport/TBufferTransports.h>
#include <thrift/transport/TSimpleFileTransport.h>
#include <thrift/transport/TFileTransport.h>
#include <thrift/protocol/TEnum.h>
#include <thrift/protocol/TList.h>
#include <thrift/protocol/TSet.h>
#include <thrift/protocol/TMap.h>
BOOST_AUTO_TEST_SUITE(ThriftReadCheckExceptionTest)
using apache::thrift::TConfiguration;
using apache::thrift::protocol::TBinaryProtocol;
using apache::thrift::protocol::TCompactProtocol;
using apache::thrift::protocol::TJSONProtocol;
using apache::thrift::protocol::TType;
using apache::thrift::transport::TPipedTransport;
using apache::thrift::transport::TMemoryBuffer;
using apache::thrift::transport::TSimpleFileTransport;
using apache::thrift::transport::TFileTransport;
using apache::thrift::transport::TFDTransport;
using apache::thrift::transport::TTransportException;
using apache::thrift::transport::TBufferedTransport;
using apache::thrift::transport::TFramedTransport;
using std::shared_ptr;
using std::cout;
using std::endl;
using std::string;
using std::memset;
using namespace apache::thrift;
using namespace apache::thrift::protocol;
BOOST_AUTO_TEST_CASE(test_tmemorybuffer_read_check_exception) {
std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
TMemoryBuffer trans_out(config);
uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
trans_out.write((const uint8_t*)buffer, sizeof(buffer));
trans_out.close();
TMemoryBuffer trans_in(config);
memset(buffer, 0, sizeof(buffer));
BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
trans_in.close();
}
BOOST_AUTO_TEST_CASE(test_tpipedtransport_read_check_exception) {
std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
std::shared_ptr<TMemoryBuffer> pipe(new TMemoryBuffer);
std::shared_ptr<TMemoryBuffer> underlying(new TMemoryBuffer);
std::shared_ptr<TPipedTransport> trans(new TPipedTransport(underlying, pipe, config));
uint8_t buffer[4];
underlying->write((uint8_t*)"abcd", 4);
BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*);
BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*);
trans->readEnd();
pipe->resetBuffer();
underlying->write((uint8_t*)"ef", 2);
BOOST_CHECK_THROW(trans->read(buffer, sizeof(buffer)), TTransportException*);
BOOST_CHECK_THROW(trans->readAll(buffer, sizeof(buffer)), TTransportException*);
trans->readEnd();
}
BOOST_AUTO_TEST_CASE(test_tsimplefiletransport_read_check_exception) {
std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
TSimpleFileTransport trans_out("data", false, true, config);
uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
trans_out.write((const uint8_t*)buffer, sizeof(buffer));
trans_out.close();
TSimpleFileTransport trans_in("data",true, false, config);
memset(buffer, 0, sizeof(buffer));
BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
trans_in.close();
remove("./data");
}
BOOST_AUTO_TEST_CASE(test_tfiletransport_read_check_exception) {
std::shared_ptr<TConfiguration> config(new TConfiguration(MAX_MESSAGE_SIZE));
TFileTransport trans_out("data", false, config);
uint8_t buffer[6] = {1, 2, 3, 4, 5, 6};
trans_out.write((const uint8_t*)buffer, sizeof(buffer));
TFileTransport trans_in("data", false, config);
memset(buffer, 0, sizeof(buffer));
BOOST_CHECK_THROW(trans_in.read(buffer, sizeof(buffer)), TTransportException*);
remove("./data");
}
BOOST_AUTO_TEST_CASE(test_tbufferedtransport_read_check_exception) {
uint8_t arr[4] = {1, 2, 3, 4};
std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr)));
std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
std::shared_ptr<TBufferedTransport> trans (new TBufferedTransport(buffer, config));
trans->write((const uint8_t*)arr, sizeof(arr));
BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*);
}
BOOST_AUTO_TEST_CASE(test_tframedtransport_read_check_exception) {
uint8_t arr[4] = {1, 2, 3, 4};
std::shared_ptr<TMemoryBuffer> buffer (new TMemoryBuffer(arr, sizeof(arr)));
std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
std::shared_ptr<TFramedTransport> trans (new TFramedTransport(buffer, config));
trans->write((const uint8_t*)arr, sizeof(arr));
BOOST_CHECK_THROW(trans->read(arr, sizeof(arr)), TTransportException*);
}
BOOST_AUTO_TEST_CASE(test_tthriftbinaryprotocol_read_check_exception) {
std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
std::shared_ptr<TBinaryProtocol> protocol(new TBinaryProtocol(transport));
uint32_t val = 0;
TType elemType = apache::thrift::protocol::T_STOP;
TType elemType1 = apache::thrift::protocol::T_STOP;
TList list(T_I32, 8);
protocol->writeListBegin(list.elemType_, list.size_);
protocol->writeListEnd();
BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
protocol->readListEnd();
TSet set(T_I32, 8);
protocol->writeSetBegin(set.elemType_, set.size_);
protocol->writeSetEnd();
BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
protocol->readSetEnd();
TMap map(T_I32, T_I32, 8);
protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
protocol->writeMapEnd();
BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
protocol->readMapEnd();
}
BOOST_AUTO_TEST_CASE(test_tthriftcompactprotocol_read_check_exception) {
std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
std::shared_ptr<TCompactProtocol> protocol(new TCompactProtocol(transport));
uint32_t val = 0;
TType elemType = apache::thrift::protocol::T_STOP;
TType elemType1 = apache::thrift::protocol::T_STOP;
TList list(T_I32, 8);
protocol->writeListBegin(list.elemType_, list.size_);
protocol->writeListEnd();
BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
protocol->readListEnd();
TSet set(T_I32, 8);
protocol->writeSetBegin(set.elemType_, set.size_);
protocol->writeSetEnd();
BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
protocol->readSetEnd();
TMap map(T_I32, T_I32, 8);
protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
protocol->writeMapEnd();
BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
protocol->readMapEnd();
}
BOOST_AUTO_TEST_CASE(test_tthriftjsonprotocol_read_check_exception) {
std::shared_ptr<TConfiguration> config (new TConfiguration(MAX_MESSAGE_SIZE));
std::shared_ptr<TMemoryBuffer> transport(new TMemoryBuffer(config));
std::shared_ptr<TJSONProtocol> protocol(new TJSONProtocol(transport));
uint32_t val = 0;
TType elemType = apache::thrift::protocol::T_STOP;
TType elemType1 = apache::thrift::protocol::T_STOP;
TList list(T_I32, 8);
protocol->writeListBegin(list.elemType_, list.size_);
protocol->writeListEnd();
BOOST_CHECK_THROW(protocol->readListBegin(elemType, val), TTransportException*);
protocol->readListEnd();
TSet set(T_I32, 8);
protocol->writeSetBegin(set.elemType_, set.size_);
protocol->writeSetEnd();
BOOST_CHECK_THROW(protocol->readSetBegin(elemType, val), TTransportException*);
protocol->readSetEnd();
TMap map(T_I32, T_I32, 8);
protocol->writeMapBegin(map.keyType_, map.valueType_, map.size_);
protocol->writeMapEnd();
BOOST_CHECK_THROW(protocol->readMapBegin(elemType, elemType1, val), TTransportException*);
protocol->readMapEnd();
}
BOOST_AUTO_TEST_SUITE_END()