Use with statements when working with files in states.file

This commit is contained in:
Dan Colish 2012-03-04 09:28:43 -08:00
parent 4fa502bc59
commit 1bdd2bad6b

View File

@ -82,6 +82,7 @@ something like this:
'''
import codecs
from contextlib import nested # For < 2.7 compat
import os
import shutil
import difflib
@ -208,8 +209,9 @@ def _mako(sfn, name, source, user, group, mode, env, context=None):
passthrough = context if context else {}
passthrough.update(__salt__)
passthrough.update(__grains__)
template = Template(open(sfn, 'r').read())
open(tgt, 'w+').write(template.render(**passthrough))
with nested(open(sfn, 'r'), open(tgt, 'w+')) as (src, target):
template = Template(src.read())
target.write(template.render(**passthrough))
return {'result': True,
'data': tgt}
except:
@ -234,8 +236,9 @@ def _jinja(sfn, name, source, user, group, mode, env, context=None):
'data': 'Failed to import jinja'}
try:
newline = False
if open(sfn, 'rb').read().endswith('\n'):
newline = True
with open(sfn, 'rb') as source:
if source.read().endswith('\n'):
newline = True
tgt = tempfile.mkstemp()[1]
passthrough = context if context else {}
passthrough['salt'] = __salt__
@ -248,11 +251,15 @@ def _jinja(sfn, name, source, user, group, mode, env, context=None):
passthrough['env'] = env
template = get_template(sfn, __opts__, env)
try:
open(tgt, 'w+').write(template.render(**passthrough))
with open(tgt, 'w+') as target:
target.write(template.render(**passthrough))
if newline:
target.write('\n')
except UnicodeEncodeError:
codecs.open(tgt, encoding='utf-8', mode='w+').write(template.render(**passthrough))
if newline:
open(tgt, 'a').write('\n')
with codecs.open(tgt, encoding='utf-8', mode='w+') as target:
target.write(template.render(**passthrough))
if newline:
target.write('\n')
return {'result': True,
'data': tgt}
except:
@ -289,7 +296,8 @@ def _py(sfn, name, source, user, group, mode, env, context=None):
try:
tgt = tempfile.mkstemp()[1]
open(tgt, 'w+').write(mod.run())
with open(tgt, 'w+') as target:
target.write(mod.run())
return {'result': True,
'data': tgt}
except:
@ -543,7 +551,9 @@ def managed(name,
return ret
if data['result']:
sfn = data['data']
hsum = hashlib.md5(open(sfn, 'r').read()).hexdigest()
hsum = ''
with open(sfn, 'r') as source:
hsum = hashlib.md5(source.read()).hexdigest()
source_sum = {'hash_type': 'md5',
'hsum': hsum}
else:
@ -571,7 +581,9 @@ def managed(name,
source_hash
)
return ret
comps = open(hash_fn, 'r').read().split('=')
comps = []
with open(hash_fn, 'r') as hashfile:
comps = hashfile.read().split('=')
if len(comps) < 2:
ret['result'] = False
ret['comment'] = ('Source hash file {0} contains an '
@ -650,8 +662,10 @@ def managed(name,
# Only test the checksums on files with managed contents
if source:
name_sum = getattr(hashlib, source_sum['hash_type'])(open(name,
'rb').read()).hexdigest()
name_sum = ''
hash_func = getattr(hashlib, source_sum['hash_type'])
with open(name, 'rb') as namefile:
name_sum = hash_func(namefile.read()).hexdigest()
# Check if file needs to be replaced
if source and source_sum['hsum'] != name_sum:
@ -665,12 +679,13 @@ def managed(name,
if _is_bin(sfn) or _is_bin(name):
ret['changes']['diff'] = 'Replace binary file'
else:
slines = open(sfn, 'rb').readlines()
nlines = open(name, 'rb').readlines()
with nested(open(sfn, 'rb'), open(name, 'rb')) as (src, name_):
slines = src.readlines()
nlines = name_.readlines()
# Print a diff equivalent to diff -u old new
ret['changes']['diff'] = (''.join(difflib
.unified_diff(nlines,
slines)))
ret['changes']['diff'] = (''.join(difflib
.unified_diff(nlines,
slines)))
# Pre requisites are met, and the file needs to be replaced, do it
if not __opts__['test']:
shutil.copyfile(sfn, name)
@ -717,7 +732,6 @@ def managed(name,
# Create the file, user-rw-only if mode will be set
if mode:
cumask = os.umask(384)
open(name, 'a+').close()
if mode:
os.umask(cumask)
ret['changes']['new'] = 'file {0} created'.format(name)
@ -1049,9 +1063,12 @@ def recurse(name,
_makedirs(dest)
if os.path.isfile(dest):
keep.add(dest)
srch = ''
dsth = ''
# The file is present, if the sum differes replace it
srch = hashlib.md5(open(fn_, 'r').read()).hexdigest()
dsth = hashlib.md5(open(dest, 'r').read()).hexdigest()
with nested(open(fn_, 'r'), open(dest, 'r')) as (src_, dst_):
srch = hashlib.md5(src_.read()).hexdigest()
dsth = hashlib.md5(dst_.read()).hexdigest()
if srch != dsth:
# The downloaded file differes, replace!
# FIXME: no metadata (ownership, permissions) available