Forgot to add support for missing bytestring replacements

This commit is contained in:
Kovid Goyal 2016-06-16 11:13:44 +05:30
parent c9a8bd9884
commit 46f9c7c780
2 changed files with 19 additions and 11 deletions

View File

@ -276,24 +276,28 @@ class LocalZipFile(object):
replacements = {name:datastream}
replacements.update(extra_replacements)
names = frozenset(replacements.keys())
found = set([])
found = set()
def rbytes(name):
r = replacements[name]
if not isinstance(r, bytes):
r = r.read()
return r
with SpooledTemporaryFile(max_size=100*1024*1024) as temp:
ztemp = ZipFile(temp, 'w')
for offset, header in self.file_info.itervalues():
if header.filename in names:
zi = ZipInfo(header.filename)
zi.compress_type = header.compression_method
r = replacements[header.filename]
if not isinstance(r, bytes):
r = r.read()
ztemp.writestr(zi, r)
ztemp.writestr(zi, rbytes(header.filename))
found.add(header.filename)
else:
ztemp.writestr(header.filename, self.read(header.filename,
spool_size=0))
if add_missing:
for name in names - found:
ztemp.writestr(name, replacements[name].read())
ztemp.writestr(name, rbytes(name))
ztemp.close()
zipstream = self.stream
temp.seek(0)

View File

@ -1470,22 +1470,26 @@ def safe_replace(zipstream, name, datastream, extra_replacements={},
replacements.update(extra_replacements)
names = frozenset(replacements.keys())
found = set([])
def rbytes(name):
r = replacements[name]
if not isinstance(r, bytes):
r = r.read()
return r
with SpooledTemporaryFile(max_size=100*1024*1024) as temp:
ztemp = ZipFile(temp, 'w')
for obj in z.infolist():
if isinstance(obj.filename, unicode):
obj.flag_bits |= 0x16 # Set isUTF-8 bit
if obj.filename in names:
r = replacements[obj.filename]
if not isinstance(r, bytes):
r = r.read()
ztemp.writestr(obj, r)
ztemp.writestr(obj, rbytes(obj.filename))
found.add(obj.filename)
else:
ztemp.writestr(obj, z.read_raw(obj), raw_bytes=True)
if add_missing:
for name in names - found:
ztemp.writestr(name, replacements[name].read())
ztemp.writestr(name, rbytes(name))
ztemp.close()
z.close()
temp.seek(0)