diff --git a/src/calibre/utils/zipfile.py b/src/calibre/utils/zipfile.py index 5c79465dad..1cb658fc62 100644 --- a/src/calibre/utils/zipfile.py +++ b/src/calibre/utils/zipfile.py @@ -2,7 +2,6 @@ Read and write ZIP files. Modified by Kovid Goyal to support replacing files in a zip archive, detecting filename encoding, updating zip files, etc. ''' -import binascii import io import os import re @@ -11,9 +10,11 @@ import stat import struct import sys import time +import zlib from concurrent.futures import ThreadPoolExecutor -from contextlib import closing +from contextlib import closing, suppress from threading import Lock +from zlib import crc32 from calibre import sanitize_file_name from calibre.constants import filesystem_encoding @@ -22,11 +23,9 @@ from calibre.ptempfile import SpooledTemporaryFile from polyglot.builtins import as_bytes, string_or_bytes try: - import zlib # We may need its compression method - crc32 = zlib.crc32 + from calibre_extensions.speedup import pread_all except ImportError: - zlib = None - crc32 = binascii.crc32 + pread_all = None # running from source __all__ = [ 'ZIP_DEFLATED', @@ -527,10 +526,11 @@ class ZipExtFile(io.BufferedIOBase): # Search for universal newlines or line chunks. 
PATTERN = re.compile(br'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)') - def __init__(self, fileobj, mode, zipinfo, decrypter=None): + def __init__(self, fileobj, pos, mode, zipinfo, decrypter=None, pread_fd=-1): self._fileobj = fileobj + self._pread_fd = pread_fd self._decrypter = decrypter - self._orig_pos = fileobj.tell() + self._orig_pos = pos self._compress_type = zipinfo.compress_type self._compress_size = zipinfo.compress_size @@ -704,16 +704,18 @@ class ZipExtFile(io.BufferedIOBase): return data def read_raw(self): - pos = self._fileobj.tell() - self._fileobj.seek(self._orig_pos) bytes_to_read = self._compress_size if self._decrypter is not None: bytes_to_read -= 12 raw = b'' - if bytes_to_read > 0: - raw = self._fileobj.read(bytes_to_read) - self._fileobj.seek(pos) + if self._pread_fd > -1: + raw = pread_all(self._pread_fd, bytes_to_read, self._orig_pos) + else: + pos = self._fileobj.tell() + self._fileobj.seek(self._orig_pos) + raw = self._fileobj.read(bytes_to_read) + self._fileobj.seek(pos) return raw def decrypt_and_uncompress(self, raw: bytes) -> bytes: @@ -759,13 +761,7 @@ class ZipFile: if mode not in ('r', 'w', 'a'): raise RuntimeError(f'ZipFile() requires mode "r", "w", or "a" not {mode}') - if compression == ZIP_STORED: - pass - elif compression == ZIP_DEFLATED: - if not zlib: - raise RuntimeError( - 'Compression requires the (missing) zlib module') - else: + if compression not in (ZIP_STORED, ZIP_DEFLATED): raise RuntimeError(f'The compression method {compression} is not supported') self._allowZip64 = allowZip64 @@ -795,6 +791,16 @@ class ZipFile: self._filePassed = 1 self.fp = file self.filename = getattr(file, 'name', None) + self.pread_fd = -1 + if pread_all is not None: + try: + fd = self.fp.fileno() + except Exception: + fd = -1 + if fd > -1: + with suppress(Exception): + pread_all(fd, 1, 0) + self.pread_fd = fd if key == 'r': self._GetContents() @@ -1044,7 +1050,6 @@ class ZipFile: raise RuntimeError( 'Attempt to read ZIP archive that was already closed') - 
zef_file = self.fp # Make sure we have an info object if isinstance(name, ZipInfo): # 'name' is already an info object @@ -1053,18 +1058,31 @@ class ZipFile: # Get info object for name zinfo = self.getinfo(name) - zef_file.seek(zinfo.header_offset, os.SEEK_SET) + pos = zinfo.header_offset + if self.pread_fd > -1: + def read(n): + nonlocal pos + ans = pread_all(self.pread_fd, n, pos) + pos += len(ans) + return ans + else: + self.fp.seek(zinfo.header_offset, os.SEEK_SET) + def read(n): + nonlocal pos + ans = self.fp.read(n) + pos += len(ans) + return ans # Skip the file header: - fheader = zef_file.read(sizeFileHeader) + fheader = read(sizeFileHeader) if fheader[0:4] != stringFileHeader: raise BadZipfile('Bad magic number for file header') fheader = struct.unpack(structFileHeader, fheader) - fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + fname = read(fheader[_FH_FILENAME_LENGTH]) fname = decode_zip_internal_file_name(fname, zinfo.flag_bits) if fheader[_FH_EXTRA_FIELD_LENGTH]: - zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) + read(fheader[_FH_EXTRA_FIELD_LENGTH]) if fname != zinfo.orig_filename: print(f'WARNING: Header ({fname!r}) and directory ({zinfo.orig_filename!r}) filenames do not' @@ -1090,7 +1108,7 @@ class ZipFile: # completely random, while the 12th contains the MSB of the CRC, # or the MSB of the file time depending on the header type # and is used to check the correctness of the password. 
- byts = zef_file.read(12) + byts = read(12) h = list(map(zd, bytearray(byts[0:12]))) if zinfo.flag_bits & 0x8: # compare against the file type from extended local headers @@ -1101,7 +1119,7 @@ class ZipFile: if h[11] != check_byte: raise RuntimeError('Bad password for file', name) - return ZipExtFile(zef_file, mode, zinfo, zd) + return ZipExtFile(self.fp, pos, mode, zinfo, zd, self.pread_fd) def extract(self, member, path=None, pwd=None): '''Extract a member from the archive to the current working directory, @@ -1138,7 +1156,7 @@ class ZipFile: lock = Lock() def do_one(a): return self._extract_member_to(*a, lock=lock) - with ThreadPoolExecutor(thread_name_prefix='ZipFile-') as e: + with ThreadPoolExecutor(max_workers=12, thread_name_prefix='ZipFile-') as e: tuple(e.map(do_one, args)) def _get_targetpath(self, member: ZipInfo, targetpath: str) -> str: @@ -1202,12 +1220,16 @@ class ZipFile: target = open(targetpath, 'wb') with target: - if lock is None: - with closing(self.open(member, pwd=pwd)) as source: + if max(member.compress_size, member.file_size) > 256*1024*1024: + with lock, closing(self.open(member, pwd=pwd)) as source: shutil.copyfileobj(source, target) else: - with lock, closing(self.open(member, pwd=pwd)) as source: - src = source.read_raw() + if self.pread_fd > -1: + with closing(self.open(member, pwd=pwd)) as source: + src = source.read_raw() + else: + with lock, closing(self.open(member, pwd=pwd)) as source: + src = source.read_raw() src = source.decrypt_and_uncompress(src) source.check_crc(src) target.write(src) @@ -1703,10 +1725,7 @@ def main(args=None): print(USAGE) sys.exit(1) - zf = ZipFile(args[1], 'r') - out = args[2] - zf.extractall(out) - zf.close() + extractall(args[1], args[2]) elif args[0] == '-c': if len(args) < 3: