Reject multiplexed thrift requests by the server

This commit is contained in:
Anton Belyaev 2016-04-14 20:55:47 +03:00
parent 679beb2d59
commit d4c5f5edbc
2 changed files with 57 additions and 9 deletions

View File

@ -34,6 +34,10 @@
] ++ Meta) ] ++ Meta)
). ).
-define(error_unknown_function, no_function).
-define(error_multiplexed_req, multiplexed_request).
-define(error_protocol_send, send_error).
-record(state, { -record(state, {
req_id :: rpc_t:req_id(), req_id :: rpc_t:req_id(),
rpc_client :: rpc_client:client(), rpc_client :: rpc_client:client(),
@ -87,9 +91,9 @@ process(State = #state{protocol = Protocol, service = Service}) ->
Type =:= ?tMessageType_ONEWAY Type =:= ?tMessageType_ONEWAY
-> ->
State2 = release_oneway(Type, State1), State2 = release_oneway(Type, State1),
FunctionName = list_to_existing_atom(Function), FunctionName = get_function_name(Function),
prepare_response(handle_function(FunctionName, prepare_response(handle_function(FunctionName,
Service:function_info(FunctionName, params_type), get_params_type(Service, FunctionName),
State2, State2,
SeqId SeqId
), FunctionName); ), FunctionName);
@ -97,14 +101,34 @@ process(State = #state{protocol = Protocol, service = Service}) ->
handle_protocol_error(State1, undefined, Reason) handle_protocol_error(State1, undefined, Reason)
end. end.
get_function_name(Function) ->
case string:tokens(Function, ?MULTIPLEXED_SERVICE_SEPARATOR) of
[_ServiceName, _FunctionName] ->
{error, ?error_multiplexed_req};
_ ->
try list_to_existing_atom(Function)
catch
error:badarg -> {error, ?error_unknown_function}
end
end.
get_params_type(Service, Function) ->
try Service:function_info(Function, params_type)
catch
error:badarg -> ?error_unknown_function
end.
release_oneway(?tMessageType_ONEWAY, State = #state{protocol = Protocol}) -> release_oneway(?tMessageType_ONEWAY, State = #state{protocol = Protocol}) ->
{Protocol1, ok} = thrift_protocol:flush_transport(Protocol), {Protocol1, ok} = thrift_protocol:flush_transport(Protocol),
State#state{protocol = Protocol1}; State#state{protocol = Protocol1};
release_oneway(_, State) -> release_oneway(_, State) ->
State. State.
handle_function(_, no_function, State, _SeqId) -> handle_function(Error = {error, _}, _, State, _SeqId) ->
{State, {error, function_undefined}}; {State, Error};
handle_function(_, ?error_unknown_function, State, _SeqId) ->
{State, {error, ?error_unknown_function}};
handle_function(Function, InParams, State = #state{protocol = Protocol}, SeqId) -> handle_function(Function, InParams, State = #state{protocol = Protocol}, SeqId) ->
{Protocol1, ReadResult} = thrift_protocol:read(Protocol, InParams), {Protocol1, ReadResult} = thrift_protocol:read(Protocol, InParams),
@ -246,7 +270,7 @@ send_reply(State = #state{protocol = Protocol}, Function, ReplyMessageType, Repl
{State#state{protocol = Protocol4}, ok} {State#state{protocol = Protocol4}, ok}
catch catch
error:{badmatch, {_, {error, _} = Error}} -> error:{badmatch, {_, {error, _} = Error}} ->
{State, {error, {send_error, [Error, erlang:get_stacktrace()]}}} {State, {error, {?error_protocol_send, [Error, erlang:get_stacktrace()]}}}
end. end.
prepare_response({State, ok}, _) -> prepare_response({State, ok}, _) ->
@ -284,10 +308,13 @@ format_protocol_error({bad_binary_protocol_version, _Version}, Trans) ->
format_protocol_error(no_binary_protocol_version, Trans) -> format_protocol_error(no_binary_protocol_version, Trans) ->
mark_error_to_transport(Trans, transport, "no binary protocol version"), mark_error_to_transport(Trans, transport, "no binary protocol version"),
{error, badrequest}; {error, badrequest};
format_protocol_error({function_undefined, _Fun}, Trans) -> format_protocol_error({?error_unknown_function, _Fun}, Trans) ->
mark_error_to_transport(Trans, transport, "unknown method"), mark_error_to_transport(Trans, transport, "unknown method"),
{error, badrequest}; {error, badrequest};
format_protocol_error({send_error, _}, Trans) -> format_protocol_error({?error_multiplexed_req, _Fun}, Trans) ->
mark_error_to_transport(Trans, transport, "multiplexing not supported"),
{error, badrequest};
format_protocol_error({?error_protocol_send, _}, Trans) ->
mark_error_to_transport(Trans, transport, "internal error"), mark_error_to_transport(Trans, transport, "internal error"),
{error, server_error}; {error, server_error};
format_protocol_error(_Reason, Trans) -> format_protocol_error(_Reason, Trans) ->

View File

@ -107,7 +107,8 @@ all() ->
call_async_ok_test, call_async_ok_test,
checkrpc_ids_sequence_test, checkrpc_ids_sequence_test,
call_two_services_test, call_two_services_test,
call_with_client_pool_test call_with_client_pool_test,
multiplexed_transport_test
]. ].
%% %%
@ -134,7 +135,8 @@ init_per_testcase(Tc, C) when
Tc =:= call_safe_server_transport_error_test ; Tc =:= call_safe_server_transport_error_test ;
Tc =:= call_server_transport_error_test ; Tc =:= call_server_transport_error_test ;
Tc =:= call_handle_error_fails_test ; Tc =:= call_handle_error_fails_test ;
Tc =:= call_oneway_void_test Tc =:= call_oneway_void_test ;
Tc =:= multiplexed_transport_test
-> ->
do_init_per_testcase([powerups], C); do_init_per_testcase([powerups], C);
init_per_testcase(Tc, C) when init_per_testcase(Tc, C) when
@ -325,6 +327,25 @@ call_with_client_pool_test(_) ->
receive_msg({Id, Gun}), receive_msg({Id, Gun}),
ok = rpc_thrift_client:stop_pool(Pool). ok = rpc_thrift_client:stop_pool(Pool).
multiplexed_transport_test(_) ->
Id = <<"multiplexed_transport">>,
{Client1, {error, {400, _}}} = thrift_client:call(
make_thrift_multiplexed_client(Id, "powerups", get_service_endpoint(powerups)),
get_powerup,
[<<"Body Armor">>, self_to_bin()]
),
thrift_client:close(Client1).
make_thrift_multiplexed_client(Id, ServiceName, {Url, Service}) ->
{ok, Protocol} = thrift_binary_protocol:new(
rpc_thrift_http_transport:new(true, Id, Id, #{url => Url}),
[{strict_read, true}, {strict_write, true}]
),
{ok, Protocol1} = thrift_multiplexed_protocol:new(Protocol, ServiceName),
{ok, Client} = thrift_client:new(Protocol1, Service),
Client.
%% %%
%% supervisor callbacks %% supervisor callbacks
%% %%