mirror of
				https://github.com/kovidgoyal/calibre.git
				synced 2025-11-03 19:17:02 -05:00 
			
		
		
		
	Speed up zipfile.extractall by using threads
Allows I/O and decompress to be done in parallel Speed up is another 30% on decent hardware.
This commit is contained in:
		
							parent
							
								
									591bbcbb63
								
							
						
					
					
						commit
						2d01ad73c6
					
				@ -2,7 +2,6 @@
 | 
			
		||||
Read and write ZIP files. Modified by Kovid Goyal to support replacing files in
 | 
			
		||||
a zip archive, detecting filename encoding, updating zip files, etc.
 | 
			
		||||
'''
 | 
			
		||||
import binascii
 | 
			
		||||
import io
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
@ -11,9 +10,11 @@ import stat
 | 
			
		||||
import struct
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import zlib
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
from contextlib import closing
 | 
			
		||||
from contextlib import closing, suppress
 | 
			
		||||
from threading import Lock
 | 
			
		||||
from zlib import crc32
 | 
			
		||||
 | 
			
		||||
from calibre import sanitize_file_name
 | 
			
		||||
from calibre.constants import filesystem_encoding
 | 
			
		||||
@ -22,11 +23,9 @@ from calibre.ptempfile import SpooledTemporaryFile
 | 
			
		||||
from polyglot.builtins import as_bytes, string_or_bytes
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    import zlib  # We may need its compression method
 | 
			
		||||
    crc32 = zlib.crc32
 | 
			
		||||
    from calibre_extensions.speedup import pread_all
 | 
			
		||||
except ImportError:
 | 
			
		||||
    zlib = None
 | 
			
		||||
    crc32 = binascii.crc32
 | 
			
		||||
    pread_all = None  # running from source
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'ZIP_DEFLATED',
 | 
			
		||||
@ -527,10 +526,11 @@ class ZipExtFile(io.BufferedIOBase):
 | 
			
		||||
    # Search for universal newlines or line chunks.
 | 
			
		||||
    PATTERN = re.compile(br'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, fileobj, mode, zipinfo, decrypter=None):
 | 
			
		||||
    def __init__(self, fileobj, pos, mode, zipinfo, decrypter=None, pread_fd=-1):
 | 
			
		||||
        self._fileobj = fileobj
 | 
			
		||||
        self._pread_fd = pread_fd
 | 
			
		||||
        self._decrypter = decrypter
 | 
			
		||||
        self._orig_pos = fileobj.tell()
 | 
			
		||||
        self._orig_pos = pos
 | 
			
		||||
 | 
			
		||||
        self._compress_type = zipinfo.compress_type
 | 
			
		||||
        self._compress_size = zipinfo.compress_size
 | 
			
		||||
@ -704,16 +704,18 @@ class ZipExtFile(io.BufferedIOBase):
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
    def read_raw(self):
 | 
			
		||||
        pos = self._fileobj.tell()
 | 
			
		||||
        self._fileobj.seek(self._orig_pos)
 | 
			
		||||
        bytes_to_read = self._compress_size
 | 
			
		||||
        if self._decrypter is not None:
 | 
			
		||||
            bytes_to_read -= 12
 | 
			
		||||
        raw = b''
 | 
			
		||||
 | 
			
		||||
        if bytes_to_read > 0:
 | 
			
		||||
            raw = self._fileobj.read(bytes_to_read)
 | 
			
		||||
        self._fileobj.seek(pos)
 | 
			
		||||
            if self._pread_fd > -1:
 | 
			
		||||
                raw = pread_all(self._pread_fd, bytes_to_read, self._orig_pos)
 | 
			
		||||
            else:
 | 
			
		||||
                pos = self._fileobj.tell()
 | 
			
		||||
                self._fileobj.seek(self._orig_pos)
 | 
			
		||||
                raw = self._fileobj.read(bytes_to_read)
 | 
			
		||||
                self._fileobj.seek(pos)
 | 
			
		||||
        return raw
 | 
			
		||||
 | 
			
		||||
    def decrypt_and_uncompress(self, raw: bytes) -> bytes:
 | 
			
		||||
@ -759,13 +761,7 @@ class ZipFile:
 | 
			
		||||
        if mode not in ('r', 'w', 'a'):
 | 
			
		||||
            raise RuntimeError(f'ZipFile() requires mode "r", "w", or "a" not {mode}')
 | 
			
		||||
 | 
			
		||||
        if compression == ZIP_STORED:
 | 
			
		||||
            pass
 | 
			
		||||
        elif compression == ZIP_DEFLATED:
 | 
			
		||||
            if not zlib:
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                      'Compression requires the (missing) zlib module')
 | 
			
		||||
        else:
 | 
			
		||||
        if compression not in (ZIP_STORED, ZIP_DEFLATED):
 | 
			
		||||
            raise RuntimeError(f'The compression method {compression} is not supported')
 | 
			
		||||
 | 
			
		||||
        self._allowZip64 = allowZip64
 | 
			
		||||
@ -795,6 +791,16 @@ class ZipFile:
 | 
			
		||||
            self._filePassed = 1
 | 
			
		||||
            self.fp = file
 | 
			
		||||
            self.filename = getattr(file, 'name', None)
 | 
			
		||||
        self.pread_fd = -1
 | 
			
		||||
        if pread_all is not None:
 | 
			
		||||
            try:
 | 
			
		||||
                fd = self.fp.fileno()
 | 
			
		||||
            except Exception:
 | 
			
		||||
                fd = -1
 | 
			
		||||
            if fd > -1:
 | 
			
		||||
                with suppress(Exception):
 | 
			
		||||
                    pread_all(fd, 1, 0)
 | 
			
		||||
                    self.pread_fd = fd
 | 
			
		||||
 | 
			
		||||
        if key == 'r':
 | 
			
		||||
            self._GetContents()
 | 
			
		||||
@ -1044,7 +1050,6 @@ class ZipFile:
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                  'Attempt to read ZIP archive that was already closed')
 | 
			
		||||
 | 
			
		||||
        zef_file = self.fp
 | 
			
		||||
        # Make sure we have an info object
 | 
			
		||||
        if isinstance(name, ZipInfo):
 | 
			
		||||
            # 'name' is already an info object
 | 
			
		||||
@ -1053,18 +1058,31 @@ class ZipFile:
 | 
			
		||||
            # Get info object for name
 | 
			
		||||
            zinfo = self.getinfo(name)
 | 
			
		||||
 | 
			
		||||
        zef_file.seek(zinfo.header_offset, os.SEEK_SET)
 | 
			
		||||
        pos = zinfo.header_offset
 | 
			
		||||
        if self.pread_fd > -1:
 | 
			
		||||
            def read(n):
 | 
			
		||||
                nonlocal pos
 | 
			
		||||
                ans = pread_all(self.pread_fd, n, pos)
 | 
			
		||||
                pos += len(ans)
 | 
			
		||||
                return ans
 | 
			
		||||
        else:
 | 
			
		||||
            self.fp.seek(zinfo.header_offset, os.SEEK_SET)
 | 
			
		||||
            def read(n):
 | 
			
		||||
                nonlocal pos
 | 
			
		||||
                ans = self.fp.read(n)
 | 
			
		||||
                pos += len(ans)
 | 
			
		||||
                return ans
 | 
			
		||||
 | 
			
		||||
        # Skip the file header:
 | 
			
		||||
        fheader = zef_file.read(sizeFileHeader)
 | 
			
		||||
        fheader = read(sizeFileHeader)
 | 
			
		||||
        if fheader[0:4] != stringFileHeader:
 | 
			
		||||
            raise BadZipfile('Bad magic number for file header')
 | 
			
		||||
 | 
			
		||||
        fheader = struct.unpack(structFileHeader, fheader)
 | 
			
		||||
        fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
 | 
			
		||||
        fname = read(fheader[_FH_FILENAME_LENGTH])
 | 
			
		||||
        fname = decode_zip_internal_file_name(fname, zinfo.flag_bits)
 | 
			
		||||
        if fheader[_FH_EXTRA_FIELD_LENGTH]:
 | 
			
		||||
            zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
 | 
			
		||||
            read(fheader[_FH_EXTRA_FIELD_LENGTH])
 | 
			
		||||
 | 
			
		||||
        if fname != zinfo.orig_filename:
 | 
			
		||||
            print(f'WARNING: Header ({fname!r}) and directory ({zinfo.orig_filename!r}) filenames do not'
 | 
			
		||||
@ -1090,7 +1108,7 @@ class ZipFile:
 | 
			
		||||
            # completely random, while the 12th contains the MSB of the CRC,
 | 
			
		||||
            # or the MSB of the file time depending on the header type
 | 
			
		||||
            # and is used to check the correctness of the password.
 | 
			
		||||
            byts = zef_file.read(12)
 | 
			
		||||
            byts = read(12)
 | 
			
		||||
            h = list(map(zd, bytearray(byts[0:12])))
 | 
			
		||||
            if zinfo.flag_bits & 0x8:
 | 
			
		||||
                # compare against the file type from extended local headers
 | 
			
		||||
@ -1101,7 +1119,7 @@ class ZipFile:
 | 
			
		||||
            if h[11] != check_byte:
 | 
			
		||||
                raise RuntimeError('Bad password for file', name)
 | 
			
		||||
 | 
			
		||||
        return ZipExtFile(zef_file, mode, zinfo, zd)
 | 
			
		||||
        return ZipExtFile(self.fp, pos, mode, zinfo, zd, self.pread_fd)
 | 
			
		||||
 | 
			
		||||
    def extract(self, member, path=None, pwd=None):
 | 
			
		||||
        '''Extract a member from the archive to the current working directory,
 | 
			
		||||
@ -1138,7 +1156,7 @@ class ZipFile:
 | 
			
		||||
        lock = Lock()
 | 
			
		||||
        def do_one(a):
 | 
			
		||||
            return self._extract_member_to(*a, lock=lock)
 | 
			
		||||
        with ThreadPoolExecutor(thread_name_prefix='ZipFile-') as e:
 | 
			
		||||
        with ThreadPoolExecutor(max_workers=12, thread_name_prefix='ZipFile-') as e:
 | 
			
		||||
            tuple(e.map(do_one, args))
 | 
			
		||||
 | 
			
		||||
    def _get_targetpath(self, member: ZipInfo, targetpath: str) -> str:
 | 
			
		||||
@ -1202,12 +1220,16 @@ class ZipFile:
 | 
			
		||||
            target = open(targetpath, 'wb')
 | 
			
		||||
 | 
			
		||||
        with target:
 | 
			
		||||
            if lock is None:
 | 
			
		||||
                with closing(self.open(member, pwd=pwd)) as source:
 | 
			
		||||
            if max(member.compress_size, member.file_size) > 256*1024*1024:
 | 
			
		||||
                with lock, closing(self.open(member, pwd=pwd)) as source:
 | 
			
		||||
                    shutil.copyfileobj(source, target)
 | 
			
		||||
            else:
 | 
			
		||||
                with lock, closing(self.open(member, pwd=pwd)) as source:
 | 
			
		||||
                    src = source.read_raw()
 | 
			
		||||
                if self.pread_fd > -1:
 | 
			
		||||
                    with closing(self.open(member, pwd=pwd)) as source:
 | 
			
		||||
                        src = source.read_raw()
 | 
			
		||||
                else:
 | 
			
		||||
                    with lock, closing(self.open(member, pwd=pwd)) as source:
 | 
			
		||||
                        src = source.read_raw()
 | 
			
		||||
                src = source.decrypt_and_uncompress(src)
 | 
			
		||||
                source.check_crc(src)
 | 
			
		||||
                target.write(src)
 | 
			
		||||
@ -1703,10 +1725,7 @@ def main(args=None):
 | 
			
		||||
            print(USAGE)
 | 
			
		||||
            sys.exit(1)
 | 
			
		||||
 | 
			
		||||
        zf = ZipFile(args[1], 'r')
 | 
			
		||||
        out = args[2]
 | 
			
		||||
        zf.extractall(out)
 | 
			
		||||
        zf.close()
 | 
			
		||||
        extractall(args[1], args[2])
 | 
			
		||||
 | 
			
		||||
    elif args[0] == '-c':
 | 
			
		||||
        if len(args) < 3:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user