mirror of
https://github.com/valitydev/thrift.git
synced 2024-11-06 18:35:19 +00:00
10308cb975
This closes #832
109 lines
3.9 KiB
Python
109 lines
3.9 KiB
Python
#!/usr/bin/env python
|
|
|
|
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
|
|
from ThriftTest.ttypes import Bonk, VersioningTestV1, VersioningTestV2
|
|
from thrift.protocol import TJSONProtocol
|
|
from thrift.transport import TTransport
|
|
|
|
import json
|
|
import unittest
|
|
|
|
|
|
class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for the simple JSON protocol produced by TSimpleJSONProtocolFactory.

    Each test serializes a ThriftTest struct through the protocol, parses the
    output with the standard ``json`` module, and compares the result with a
    plain dictionary built from the struct's own attributes.
    """

    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that dictionaries ``a`` and ``b`` hold the same items.

        Delegates to ``unittest``'s ``assertDictEqual`` when the running
        interpreter provides it, and otherwise falls back to a simpler
        key-by-key comparison.

        Args:
            a: The expected dictionary.
            b: The actual dictionary.
            msg: Optional failure message passed to the assertions.
        """
        if hasattr(self, 'assertDictEqual'):
            # assertDictEqual was added to unittest in Python 2.7; prefer it
            # whenever it exists because its failure output is more helpful.
            self.assertDictEqual(a, b, msg)
            return

        # Substitute implementation, not as good as the unittest library's.
        # Only spellings available on every Python version are used here:
        # assertEqual (not the removed assertEquals alias) and dict.items()
        # (not the Python-2-only iteritems()).
        self.assertEqual(len(a), len(b), msg)
        for k, v in a.items():
            self.assertTrue(k in b, msg)
            self.assertEqual(b.get(k), v, msg)

    def _serialize(self, obj):
        """Write ``obj`` through the protocol and return the buffer contents.

        Args:
            obj: An object exposing ``write(protocol)``.

        Returns:
            Whatever the in-memory transport accumulated while ``obj`` wrote
            itself, as returned by ``TMemoryBuffer.getvalue()``.
        """
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Create an ``objtype`` instance and read it from ``data``.

        Args:
            objtype: A zero-argument callable returning an object that
                exposes ``read(protocol)``.
            data: The serialized payload used to seed the memory buffer.

        Returns:
            The freshly created object after ``read`` has been called on it.
        """
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading through the simple JSON protocol raises NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A struct with scalar fields serializes to a flat JSON object."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = json.loads(self._serialize(v1obj).decode('ascii'))

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """A struct with nested struct and container fields serializes to nested JSON."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        # The nested struct is expected as a nested object.
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        # The set is expected back as a JSON array.
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # Need to load/dump because map keys get escaped: JSON object keys
        # are strings, so the integer keys of newmap come back as strings.
        expected = json.loads(json.dumps(expected))
        actual = json.loads(self._serialize(v2obj).decode('ascii'))
        self._assertDictEqual(expected, actual)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed directly as a script: hand control to unittest's runner.
    unittest.main()
|