diff --git a/src/calibre/utils/zipfile.py b/src/calibre/utils/zipfile.py index 6bf6f2f83c..9943119968 100644 --- a/src/calibre/utils/zipfile.py +++ b/src/calibre/utils/zipfile.py @@ -6,8 +6,8 @@ from __future__ import with_statement import struct, os, time, sys, shutil import binascii, cStringIO from contextlib import closing +from tempfile import SpooledTemporaryFile -from calibre.ptempfile import TemporaryDirectory from calibre import sanitize_file_name from calibre.constants import filesystem_encoding from calibre.ebooks.chardet import detect @@ -467,6 +467,7 @@ class ZipExtFile: def __init__(self, fileobj, zipinfo, decrypt=None): self.fileobj = fileobj + self.orig_pos = fileobj.tell() self.decrypter = decrypt self.bytes_read = 0L self.rawbuffer = '' @@ -582,6 +583,20 @@ class ZipExtFile: result.append(line) return result + def read_raw(self): + pos = self.fileobj.tell() + self.fileobj.seek(self.orig_pos) + bytes_to_read = self.compress_size + if self.decrypter is not None: + bytes_to_read -= 12 + raw = b'' + + if bytes_to_read > 0: + raw = self.fileobj.read(bytes_to_read) + self.fileobj.seek(pos) + return raw + + def read(self, size = None): # act like file() obj and return empty string if size is 0 if size == 0: @@ -925,6 +940,11 @@ class ZipFile: """Return file bytes (as a string) for name.""" return self.open(name, "r", pwd).read() + def read_raw(self, name, mode="r", pwd=None): + """Return the raw bytes in the zipfile corresponding to name.""" + zef = self.open(name, mode=mode, pwd=pwd) + return zef.read_raw() + def open(self, name, mode="r", pwd=None): """Return file-like object for 'name'.""" if mode not in ("r", "U", "rU"): @@ -1159,10 +1179,13 @@ class ZipFile: self.filelist.append(zinfo) self.NameToInfo[zinfo.filename] = zinfo - def writestr(self, zinfo_or_arcname, bytes, permissions=0600, compression=ZIP_DEFLATED): + def writestr(self, zinfo_or_arcname, bytes, permissions=0600, + compression=ZIP_DEFLATED, raw_bytes=False): """Write a file into the archive. The contents is the string 'bytes'. 'zinfo_or_arcname' is either a ZipInfo instance or the name of the file in the archive.""" + assert not raw_bytes or (raw_bytes and + isinstance(zinfo_or_arcname, ZipInfo)) if not isinstance(zinfo_or_arcname, ZipInfo): if isinstance(zinfo_or_arcname, unicode): zinfo_or_arcname = zinfo_or_arcname.encode('utf-8') @@ -1177,18 +1200,20 @@ class ZipFile: raise RuntimeError( "Attempt to write to ZIP archive that was already closed") - zinfo.file_size = len(bytes) # Uncompressed size + if not raw_bytes: + zinfo.file_size = len(bytes) # Uncompressed size zinfo.header_offset = self.fp.tell() # Start of header bytes self._writecheck(zinfo) self._didModify = True - zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum - if zinfo.compress_type == ZIP_DEFLATED: - co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, - zlib.DEFLATED, -15) - bytes = co.compress(bytes) + co.flush() - zinfo.compress_size = len(bytes) # Compressed size - else: - zinfo.compress_size = zinfo.file_size + if not raw_bytes: + zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum + if zinfo.compress_type == ZIP_DEFLATED: + co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -15) + bytes = co.compress(bytes) + co.flush() + zinfo.compress_size = len(bytes) # Compressed size + else: + zinfo.compress_size = zinfo.file_size zinfo.header_offset = self.fp.tell() # Start of header bytes self.fp.write(zinfo.FileHeader()) self.fp.write(bytes) @@ -1332,7 +1357,7 @@ class ZipFile: def safe_replace(zipstream, name, datastream): ''' Replace a file in a zip file in a safe manner. This proceeds by extracting - and re-creating the zipfile. This is neccessary because :method:`ZipFile.replace` + and re-creating the zipfile. This is necessary because :method:`ZipFile.replace` sometimes created corrupted zip files. :param zipstream: Stream from a zip file @@ -1340,21 +1365,20 @@ def safe_replace(zipstream, name, datastream): :param datastream: The data to replace the file with. ''' z = ZipFile(zipstream, 'r') - names = z.infolist() - with TemporaryDirectory('_zipfile_replace') as tdir: - z.extractall(path=tdir) - mapping = z.extract_mapping - path = os.path.join(tdir, *name.split('/')) - shutil.copyfileobj(datastream, open(path, 'wb')) + with SpooledTemporaryFile(max_size=100*1024*1024) as temp: + ztemp = ZipFile(temp, 'w') + for obj in z.infolist(): + if obj.filename == name: + ztemp.writestr(obj, datastream.read()) + else: + ztemp.writestr(obj, z.read_raw(obj), raw_bytes=True) + ztemp.close() + z.close() + temp.seek(0) zipstream.seek(0) zipstream.truncate() - with closing(ZipFile(zipstream, 'w')) as z: - for info in names: - current = mapping[info.filename] - if os.path.isdir(current): - z.writestr(info.filename+'/', '', 0700) - else: - z.write(current, info.filename, compress_type=info.compress_type) + shutil.copyfileobj(temp, zipstream) + zipstream.flush() class PyZipFile(ZipFile): """Class to create ZIP archives with Python library files and packages."""