Fix symlink recursion in salt.utils.safe_walk

This commit is contained in:
Daniel Miller 2013-05-02 13:32:21 -04:00
parent 8370703860
commit b5d130cc3a
2 changed files with 80 additions and 28 deletions

View File

@ -1033,38 +1033,53 @@ def parse_docstring(docstring):
return ret
def safe_walk(dir_):
def safe_walk(top, topdown=True, onerror=None, followlinks=True, _seen=set()):
'''
A clone of the python os.walk function with some checks for recursive
symlinks. This functions the same way as the os.walk function but with
fewer options.
symlinks. Unlike os.walk this follows symlinks by default.
'''
dirs = []
files = []
for fn_ in os.listdir(dir_):
full = os.path.join(dir_, fn_)
try:
mode = os.lstat(full).st_mode
except os.error:
continue
if stat.S_ISLNK(mode):
real = os.path.realpath(full)
try:
smode = os.stat(real).st_mode
if stat.S_ISDIR(smode):
if not full.startswith(real):
mode = smode
except os.error:
pass
islink, join, isdir = os.path.islink, os.path.join, os.path.isdir
if stat.S_ISDIR(mode):
dirs.append(full)
elif stat.S_ISREG(mode):
files.append(full)
yield dir_, dirs, files
for sdir in dirs:
for ret in safe_walk(sdir):
yield ret
# We may not have read permission for top, in which case we can't
# get a list of the files the directory contains. os.path.walk
# always suppressed the exception then, rather than blow up for a
# minor reason when (say) a thousand readable directories are still
# left to visit. That logic is copied here.
try:
# Note that listdir and error are globals in this module due
# to earlier import-*.
names = os.listdir(top)
except os.error, err:
if onerror is not None:
onerror(err)
return
if followlinks:
stat = os.stat(top)
# st_ino is always 0 on some filesystems (FAT, NTFS); ignore them
if stat.st_ino != 0:
node = (stat.st_dev, stat.st_ino)
if node in _seen:
return
_seen.add(node)
dirs, nondirs = [], []
for name in names:
full_path = join(top, name)
if isdir(full_path):
dirs.append(name)
else:
nondirs.append(name)
if topdown:
yield top, dirs, nondirs
for name in dirs:
new_path = join(top, name)
if followlinks or not islink(new_path):
for x in safe_walk(new_path, topdown, onerror, followlinks, _seen):
yield x
if not topdown:
yield top, dirs, nondirs
def get_hash(path, form='md5', chunk_size=4096):

View File

@ -0,0 +1,37 @@
import os
from os.path import join
from shutil import rmtree
from tempfile import mkdtemp
# Import salt libs
import salt.utils
import salt.utils.find
from saltunittest import TestCase, TestLoader, TextTestRunner, skipIf
class TestUtils(TestCase):
def test_safe_walk_symlink_recursion(self):
tmp = mkdtemp()
try:
if os.stat(tmp).st_ino == 0:
self.skipTest("inodes not supported in {}".format(tmp))
os.mkdir(join(tmp, "fax"))
os.makedirs(join(tmp, "foo/bar"))
os.symlink("../..", join(tmp, "foo/bar/baz"))
os.symlink("foo", join(tmp, "root"))
expected = [
(join(tmp, 'root'), ['bar'], []),
(join(tmp, 'root/bar'), ['baz'], []),
(join(tmp, 'root/bar/baz'), ['fax', 'foo', 'root'], []),
(join(tmp, 'root/bar/baz/fax'), [], []),
]
paths = []
for root, dirs, names in salt.utils.safe_walk(join(tmp, "root")):
paths.append((root, dirs, names))
if paths != expected:
raise AssertionError("\n".join(["got:"]
+ [repr(p) for p in paths]
+ ["", "expected:"] + [repr(p) for p in expected]))
finally:
rmtree(tmp)