# -*- coding: utf-8 -*- ''' :codeauthor: :email:`Mike Place ` ''' # Import python libs from __future__ import absolute_import import os import logging import tornado.gen import tornado.ioloop import tornado.testing import salt.utils import salt.config import salt.exceptions import salt.transport.ipc import salt.transport.server import salt.transport.client import salt.ext.six as six from salt.ext.six.moves import range # Import Salt Testing libs import integration from salttesting.mock import MagicMock from salttesting.helpers import ensure_in_syspath log = logging.getLogger(__name__) ensure_in_syspath('../') class BaseIPCReqCase(tornado.testing.AsyncTestCase): ''' Test the req server/client pair ''' def setUp(self): super(BaseIPCReqCase, self).setUp() self._start_handlers = dict(self.io_loop._handlers) self.socket_path = os.path.join(integration.TMP, 'ipc_test.ipc') self.server_channel = salt.transport.ipc.IPCMessageServer( self.socket_path, io_loop=self.io_loop, payload_handler=self._handle_payload, ) self.server_channel.start() self.payloads = [] def tearDown(self): super(BaseIPCReqCase, self).tearDown() failures = [] self.server_channel.close() os.unlink(self.socket_path) for k, v in six.iteritems(self.io_loop._handlers): if self._start_handlers.get(k) != v: failures.append((k, v)) if len(failures) > 0: raise Exception('FDs still attached to the IOLoop: {0}'.format(failures)) @tornado.gen.coroutine def _handle_payload(self, payload, reply_func): self.payloads.append(payload) yield reply_func(payload) if isinstance(payload, dict) and payload.get('stop'): self.stop() class IPCMessageClient(BaseIPCReqCase): ''' Test all of the clear msg stuff ''' def _get_channel(self): channel = salt.transport.ipc.IPCMessageClient( socket_path=self.socket_path, io_loop=self.io_loop, ) channel.connect(callback=self.stop) self.wait() return channel def setUp(self): super(IPCMessageClient, self).setUp() self.channel = self._get_channel() def tearDown(self): super(IPCMessageClient, self).setUp() self.channel.close() def test_basic_send(self): msg = {'foo': 'bar', 'stop': True} self.channel.send(msg) self.wait() self.assertEqual(self.payloads[0], msg) def test_many_send(self): msgs = [] self.server_channel.stream_handler = MagicMock() for i in range(0, 1000): msgs.append('test_many_send_{0}'.format(i)) for i in msgs: self.channel.send(i) self.channel.send({'stop': True}) self.wait() self.assertEqual(self.payloads[:-1], msgs) def test_very_big_message(self): long_str = ''.join([str(num) for num in range(10**5)]) msg = {'long_str': long_str, 'stop': True} self.channel.send(msg) self.wait() self.assertEqual(msg, self.payloads[0]) def test_multistream_sends(self): local_channel = self._get_channel() for c in (self.channel, local_channel): c.send('foo') self.channel.send({'stop': True}) self.wait() self.assertEqual(self.payloads[:-1], ['foo', 'foo']) def test_multistream_errors(self): local_channel = self._get_channel() for c in (self.channel, local_channel): c.send(None) for c in (self.channel, local_channel): c.send('foo') self.channel.send({'stop': True}) self.wait() self.assertEqual(self.payloads[:-1], [None, None, 'foo', 'foo']) if __name__ == '__main__': from integration import run_tests run_tests(IPCMessageClient, needs_daemon=False)