Forgot to add support for missing bytestring replacements

This commit is contained in:
Kovid Goyal 2016-06-16 11:13:44 +05:30
parent c9a8bd9884
commit 46f9c7c780
2 changed files with 19 additions and 11 deletions

View File

@ -276,24 +276,28 @@ class LocalZipFile(object):
replacements = {name:datastream} replacements = {name:datastream}
replacements.update(extra_replacements) replacements.update(extra_replacements)
names = frozenset(replacements.keys()) names = frozenset(replacements.keys())
found = set([]) found = set()
def rbytes(name):
r = replacements[name]
if not isinstance(r, bytes):
r = r.read()
return r
with SpooledTemporaryFile(max_size=100*1024*1024) as temp: with SpooledTemporaryFile(max_size=100*1024*1024) as temp:
ztemp = ZipFile(temp, 'w') ztemp = ZipFile(temp, 'w')
for offset, header in self.file_info.itervalues(): for offset, header in self.file_info.itervalues():
if header.filename in names: if header.filename in names:
zi = ZipInfo(header.filename) zi = ZipInfo(header.filename)
zi.compress_type = header.compression_method zi.compress_type = header.compression_method
r = replacements[header.filename] ztemp.writestr(zi, rbytes(header.filename))
if not isinstance(r, bytes):
r = r.read()
ztemp.writestr(zi, r)
found.add(header.filename) found.add(header.filename)
else: else:
ztemp.writestr(header.filename, self.read(header.filename, ztemp.writestr(header.filename, self.read(header.filename,
spool_size=0)) spool_size=0))
if add_missing: if add_missing:
for name in names - found: for name in names - found:
ztemp.writestr(name, replacements[name].read()) ztemp.writestr(name, rbytes(name))
ztemp.close() ztemp.close()
zipstream = self.stream zipstream = self.stream
temp.seek(0) temp.seek(0)

View File

@ -1470,22 +1470,26 @@ def safe_replace(zipstream, name, datastream, extra_replacements={},
replacements.update(extra_replacements) replacements.update(extra_replacements)
names = frozenset(replacements.keys()) names = frozenset(replacements.keys())
found = set([]) found = set([])
def rbytes(name):
r = replacements[name]
if not isinstance(r, bytes):
r = r.read()
return r
with SpooledTemporaryFile(max_size=100*1024*1024) as temp: with SpooledTemporaryFile(max_size=100*1024*1024) as temp:
ztemp = ZipFile(temp, 'w') ztemp = ZipFile(temp, 'w')
for obj in z.infolist(): for obj in z.infolist():
if isinstance(obj.filename, unicode): if isinstance(obj.filename, unicode):
obj.flag_bits |= 0x16 # Set isUTF-8 bit obj.flag_bits |= 0x16 # Set isUTF-8 bit
if obj.filename in names: if obj.filename in names:
r = replacements[obj.filename] ztemp.writestr(obj, rbytes(obj.filename))
if not isinstance(r, bytes):
r = r.read()
ztemp.writestr(obj, r)
found.add(obj.filename) found.add(obj.filename)
else: else:
ztemp.writestr(obj, z.read_raw(obj), raw_bytes=True) ztemp.writestr(obj, z.read_raw(obj), raw_bytes=True)
if add_missing: if add_missing:
for name in names - found: for name in names - found:
ztemp.writestr(name, replacements[name].read()) ztemp.writestr(name, rbytes(name))
ztemp.close() ztemp.close()
z.close() z.close()
temp.seek(0) temp.seek(0)