Make replacing of files in ZIP archives faster and hopefullt more robust

This commit is contained in:
Kovid Goyal 2010-02-15 11:01:39 -07:00
parent 1c2229d0eb
commit 137d83c0e2

View File

@ -6,8 +6,8 @@ from __future__ import with_statement
import struct, os, time, sys, shutil import struct, os, time, sys, shutil
import binascii, cStringIO import binascii, cStringIO
from contextlib import closing from contextlib import closing
from tempfile import SpooledTemporaryFile
from calibre.ptempfile import TemporaryDirectory
from calibre import sanitize_file_name from calibre import sanitize_file_name
from calibre.constants import filesystem_encoding from calibre.constants import filesystem_encoding
from calibre.ebooks.chardet import detect from calibre.ebooks.chardet import detect
@ -467,6 +467,7 @@ class ZipExtFile:
def __init__(self, fileobj, zipinfo, decrypt=None): def __init__(self, fileobj, zipinfo, decrypt=None):
self.fileobj = fileobj self.fileobj = fileobj
self.orig_pos = fileobj.tell()
self.decrypter = decrypt self.decrypter = decrypt
self.bytes_read = 0L self.bytes_read = 0L
self.rawbuffer = '' self.rawbuffer = ''
@ -582,6 +583,20 @@ class ZipExtFile:
result.append(line) result.append(line)
return result return result
def read_raw(self):
pos = self.fileobj.tell()
self.fileobj.seek(self.orig_pos)
bytes_to_read = self.compress_size
if self.decrypter is not None:
bytes_to_read -= 12
raw = b''
if bytes_to_read > 0:
raw = self.fileobj.read(bytes_to_read)
self.fileobj.seek(pos)
return raw
def read(self, size = None): def read(self, size = None):
# act like file() obj and return empty string if size is 0 # act like file() obj and return empty string if size is 0
if size == 0: if size == 0:
@ -925,6 +940,11 @@ class ZipFile:
"""Return file bytes (as a string) for name.""" """Return file bytes (as a string) for name."""
return self.open(name, "r", pwd).read() return self.open(name, "r", pwd).read()
def read_raw(self, name, mode="r", pwd=None):
"""Return the raw bytes in the zipfile corresponding to name."""
zef = self.open(name, mode=mode, pwd=pwd)
return zef.read_raw()
def open(self, name, mode="r", pwd=None): def open(self, name, mode="r", pwd=None):
"""Return file-like object for 'name'.""" """Return file-like object for 'name'."""
if mode not in ("r", "U", "rU"): if mode not in ("r", "U", "rU"):
@ -1159,10 +1179,13 @@ class ZipFile:
self.filelist.append(zinfo) self.filelist.append(zinfo)
self.NameToInfo[zinfo.filename] = zinfo self.NameToInfo[zinfo.filename] = zinfo
def writestr(self, zinfo_or_arcname, bytes, permissions=0600, compression=ZIP_DEFLATED): def writestr(self, zinfo_or_arcname, bytes, permissions=0600,
compression=ZIP_DEFLATED, raw_bytes=False):
"""Write a file into the archive. The contents is the string """Write a file into the archive. The contents is the string
'bytes'. 'zinfo_or_arcname' is either a ZipInfo instance or 'bytes'. 'zinfo_or_arcname' is either a ZipInfo instance or
the name of the file in the archive.""" the name of the file in the archive."""
assert not raw_bytes or (raw_bytes and
isinstance(zinfo_or_arcname, ZipInfo))
if not isinstance(zinfo_or_arcname, ZipInfo): if not isinstance(zinfo_or_arcname, ZipInfo):
if isinstance(zinfo_or_arcname, unicode): if isinstance(zinfo_or_arcname, unicode):
zinfo_or_arcname = zinfo_or_arcname.encode('utf-8') zinfo_or_arcname = zinfo_or_arcname.encode('utf-8')
@ -1177,18 +1200,20 @@ class ZipFile:
raise RuntimeError( raise RuntimeError(
"Attempt to write to ZIP archive that was already closed") "Attempt to write to ZIP archive that was already closed")
zinfo.file_size = len(bytes) # Uncompressed size if not raw_bytes:
zinfo.file_size = len(bytes) # Uncompressed size
zinfo.header_offset = self.fp.tell() # Start of header bytes zinfo.header_offset = self.fp.tell() # Start of header bytes
self._writecheck(zinfo) self._writecheck(zinfo)
self._didModify = True self._didModify = True
zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum if not raw_bytes:
if zinfo.compress_type == ZIP_DEFLATED: zinfo.CRC = crc32(bytes) & 0xffffffff # CRC-32 checksum
co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, if zinfo.compress_type == ZIP_DEFLATED:
zlib.DEFLATED, -15) co = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
bytes = co.compress(bytes) + co.flush() zlib.DEFLATED, -15)
zinfo.compress_size = len(bytes) # Compressed size bytes = co.compress(bytes) + co.flush()
else: zinfo.compress_size = len(bytes) # Compressed size
zinfo.compress_size = zinfo.file_size else:
zinfo.compress_size = zinfo.file_size
zinfo.header_offset = self.fp.tell() # Start of header bytes zinfo.header_offset = self.fp.tell() # Start of header bytes
self.fp.write(zinfo.FileHeader()) self.fp.write(zinfo.FileHeader())
self.fp.write(bytes) self.fp.write(bytes)
@ -1332,7 +1357,7 @@ class ZipFile:
def safe_replace(zipstream, name, datastream): def safe_replace(zipstream, name, datastream):
''' '''
Replace a file in a zip file in a safe manner. This proceeds by extracting Replace a file in a zip file in a safe manner. This proceeds by extracting
and re-creating the zipfile. This is neccessary because :method:`ZipFile.replace` and re-creating the zipfile. This is necessary because :method:`ZipFile.replace`
sometimes created corrupted zip files. sometimes created corrupted zip files.
:param zipstream: Stream from a zip file :param zipstream: Stream from a zip file
@ -1340,21 +1365,20 @@ def safe_replace(zipstream, name, datastream):
:param datastream: The data to replace the file with. :param datastream: The data to replace the file with.
''' '''
z = ZipFile(zipstream, 'r') z = ZipFile(zipstream, 'r')
names = z.infolist() with SpooledTemporaryFile(max_size=100*1024*1024) as temp:
with TemporaryDirectory('_zipfile_replace') as tdir: ztemp = ZipFile(temp, 'w')
z.extractall(path=tdir) for obj in z.infolist():
mapping = z.extract_mapping if obj.filename == name:
path = os.path.join(tdir, *name.split('/')) ztemp.writestr(obj, datastream.read())
shutil.copyfileobj(datastream, open(path, 'wb')) else:
ztemp.writestr(obj, z.read_raw(obj), raw_bytes=True)
ztemp.close()
z.close()
temp.seek(0)
zipstream.seek(0) zipstream.seek(0)
zipstream.truncate() zipstream.truncate()
with closing(ZipFile(zipstream, 'w')) as z: shutil.copyfileobj(temp, zipstream)
for info in names: zipstream.flush()
current = mapping[info.filename]
if os.path.isdir(current):
z.writestr(info.filename+'/', '', 0700)
else:
z.write(current, info.filename, compress_type=info.compress_type)
class PyZipFile(ZipFile): class PyZipFile(ZipFile):
"""Class to create ZIP archives with Python library files and packages.""" """Class to create ZIP archives with Python library files and packages."""