add tests for modules/hosts
This commit is contained in:
Martin Schnabel 2011-12-31 00:21:08 +01:00
parent 6c42208602
commit 5c0a055fb7
3 changed files with 127 additions and 8 deletions

View File

@ -15,7 +15,7 @@ def list_hosts():
salt '*' hosts.list_hosts
'''
hfn = '/etc/hosts'
hfn = list_hosts.hosts_filename
ret = {}
if not os.path.isfile(hfn):
return ret
@ -26,8 +26,13 @@ def list_hosts():
if line.startswith('#'):
continue
comps = line.split()
ret[comps[0]] = comps[1:]
if comps[0] in ret:
# maybe log a warning ?
ret[comps[0]].extend(comps[1:])
else:
ret[comps[0]] = comps[1:]
return ret
list_hosts.hosts_filename = '/etc/hosts'
def get_ip(host):
@ -78,13 +83,13 @@ def has_pair(ip, alias):
def set_host(ip, alias):
'''
Set the host entry in th hosts file for the given ip, this will overwrite
Set the host entry in the hosts file for the given ip, this will overwrite
any previous entry for the given ip
CLI Example::
salt '*' hosts.set_host <ip> <alias>
'''
hfn = '/etc/hosts'
hfn = set_host.hosts_filename
ovr = False
if not os.path.isfile(hfn):
return False
@ -97,13 +102,20 @@ def set_host(ip, alias):
continue
comps = tmpline.split()
if comps[0] == ip:
lines[ind] = ip + '\t\t' + alias + '\n'
ovr = True
if not ovr:
lines[ind] = ip + '\t\t' + alias + '\n'
ovr = True
else: # remove other entries
lines[ind] = ''
if not ovr:
# make sure there is a newline
if lines and not lines[-1].endswith(('\n', '\r')):
lines[-1] = '%s\n' % lines[-1]
line = ip + '\t\t' + alias + '\n'
lines.append(line)
open(hfn, 'w+').writelines(lines)
return True
set_host.hosts_filename = '/etc/hosts'
def rm_host(ip, alias):
@ -115,7 +127,7 @@ def rm_host(ip, alias):
'''
if not has_pair(ip, alias):
return True
hfn = '/etc/hosts'
hfn = rm_host.hosts_filename
lines = open(hfn).readlines()
for ind in range(len(lines)):
tmpline = lines[ind].strip()
@ -136,6 +148,7 @@ def rm_host(ip, alias):
lines[ind] = newline
open(hfn, 'w+').writelines(lines)
return True
rm_host.hosts_filename = '/etc/hosts'
def add_host(ip, alias):
@ -146,7 +159,7 @@ def add_host(ip, alias):
CLI Example::
salt '*' hosts.add_host <ip> <alias>
'''
hfn = '/etc/hosts'
hfn = add_host.hosts_filename
ovr = False
if not os.path.isfile(hfn):
return False
@ -165,8 +178,14 @@ def add_host(ip, alias):
newline += '\t' + alias
lines.append(newline)
ovr = True
# leave any other matching entries alone
break
if not ovr:
# make sure there is a newline
if lines and not lines[-1].endswith(('\n', '\r')):
lines[-1] = '%s\n' % lines[-1]
line = ip + '\t\t' + alias + '\n'
lines.append(line)
open(hfn, 'w+').writelines(lines)
return True
add_host.hosts_filename = '/etc/hosts'

11
tests/modules/files/hosts Normal file
View File

@ -0,0 +1,11 @@
# a comment
127.0.0.1 localhost
# second alias for same ip. 'man hosts' does not allow it
# but the ubuntu and probably other distributions treat it as
# 127.0.0.1 localhost myname
127.0.0.1 myname
::1 ip6-localhost ip6-loopback
fe00::0 ip6-localnet
ff00::0 ip6-mcastprefix
ff02::1 ip6-allnodes
ff02::2 ip6-allrouters

89
tests/modules/hosts.py Normal file
View File

@ -0,0 +1,89 @@
import unittest
from salt.modules.hosts import list_hosts, get_ip, get_alias, has_pair, add_host,\
set_host, rm_host
from os import path
import os
import shutil
TEMPLATES_DIR = path.dirname(path.abspath(__file__))
monkey_pathed = (list_hosts, set_host, add_host, rm_host)
class HostsModuleTest(unittest.TestCase):
def setUp(self):
self._hfn = [f.hosts_filename for f in monkey_pathed]
self.files = path.join(TEMPLATES_DIR, 'files')
self.hostspath = path.join(self.files, 'hosts')
self.not_found = path.join(self.files, 'not_found')
self.tmpfiles = []
def tearDown(self):
for i, f in enumerate(monkey_pathed):
f.hosts_filename = self._hfn[i]
for tmp in self.tmpfiles:
os.remove(tmp)
def tmp_hosts_file(self, src):
tmpfile = path.join(self.files, 'tmp')
self.tmpfiles.append(tmpfile)
shutil.copy(src, tmpfile)
return tmpfile
def test_list_hosts(self):
list_hosts.hosts_filename = self.hostspath
hosts = list_hosts()
self.assertEqual(len(hosts), 6)
self.assertEqual(hosts['::1'], ['ip6-localhost', 'ip6-loopback'])
self.assertEqual(hosts['127.0.0.1'], ['localhost', 'myname'])
def test_list_hosts_nofile(self):
list_hosts.hosts_filename = self.not_found
hosts = list_hosts()
self.assertEqual(hosts, {})
def test_get_ip(self):
list_hosts.hosts_filename = self.hostspath
self.assertEqual(get_ip('myname'), '127.0.0.1')
self.assertEqual(get_ip('othername'), '')
list_hosts.hosts_filename = self.not_found
self.assertEqual(get_ip('othername'), '')
def test_get_alias(self):
list_hosts.hosts_filename = self.hostspath
self.assertEqual(get_alias('127.0.0.1'), ['localhost', 'myname'])
self.assertEqual(get_alias('127.0.0.2'), [])
list_hosts.hosts_filename = self.not_found
self.assertEqual(get_alias('127.0.0.1'), [])
def test_has_pair(self):
list_hosts.hosts_filename = self.hostspath
self.assertTrue(has_pair('127.0.0.1', 'myname'))
self.assertFalse(has_pair('127.0.0.1', 'othername'))
def test_set_host(self):
tmp = self.tmp_hosts_file(self.hostspath)
list_hosts.hosts_filename = tmp
set_host.hosts_filename = tmp
assert set_host('192.168.1.123', 'newip')
self.assertTrue(has_pair('192.168.1.123', 'newip'))
self.assertEqual(len(list_hosts()), 7)
assert set_host('127.0.0.1', 'localhost')
self.assertFalse(has_pair('127.0.0.1', 'myname'), 'should remove second entry')
def test_add_host(self):
tmp = self.tmp_hosts_file(self.hostspath)
list_hosts.hosts_filename = tmp
add_host.hosts_filename = tmp
assert add_host('192.168.1.123', 'newip')
self.assertTrue(has_pair('192.168.1.123', 'newip'))
self.assertEqual(len(list_hosts()), 7)
assert add_host('127.0.0.1', 'othernameip')
self.assertEqual(len(list_hosts()), 7)
def test_rm_host(self):
tmp = self.tmp_hosts_file(self.hostspath)
list_hosts.hosts_filename = tmp
rm_host.hosts_filename = tmp
assert has_pair('127.0.0.1', 'myname')
assert rm_host('127.0.0.1', 'myname')
assert not has_pair('127.0.0.1', 'myname')
assert rm_host('127.0.0.1', 'unknown')