mirror of
				https://github.com/kovidgoyal/calibre.git
				synced 2025-11-03 19:17:02 -05:00 
			
		
		
		
	Speed up zipfile.extractall by using threads
Allows I/O and decompression to be done in parallel. The speed-up is another 30% on decent hardware.
This commit is contained in:
		
							parent
							
								
									591bbcbb63
								
							
						
					
					
						commit
						2d01ad73c6
					
				@ -2,7 +2,6 @@
 | 
				
			|||||||
Read and write ZIP files. Modified by Kovid Goyal to support replacing files in
 | 
					Read and write ZIP files. Modified by Kovid Goyal to support replacing files in
 | 
				
			||||||
a zip archive, detecting filename encoding, updating zip files, etc.
 | 
					a zip archive, detecting filename encoding, updating zip files, etc.
 | 
				
			||||||
'''
 | 
					'''
 | 
				
			||||||
import binascii
 | 
					 | 
				
			||||||
import io
 | 
					import io
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
@ -11,9 +10,11 @@ import stat
 | 
				
			|||||||
import struct
 | 
					import struct
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					import zlib
 | 
				
			||||||
from concurrent.futures import ThreadPoolExecutor
 | 
					from concurrent.futures import ThreadPoolExecutor
 | 
				
			||||||
from contextlib import closing
 | 
					from contextlib import closing, suppress
 | 
				
			||||||
from threading import Lock
 | 
					from threading import Lock
 | 
				
			||||||
 | 
					from zlib import crc32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from calibre import sanitize_file_name
 | 
					from calibre import sanitize_file_name
 | 
				
			||||||
from calibre.constants import filesystem_encoding
 | 
					from calibre.constants import filesystem_encoding
 | 
				
			||||||
@ -22,11 +23,9 @@ from calibre.ptempfile import SpooledTemporaryFile
 | 
				
			|||||||
from polyglot.builtins import as_bytes, string_or_bytes
 | 
					from polyglot.builtins import as_bytes, string_or_bytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    import zlib  # We may need its compression method
 | 
					    from calibre_extensions.speedup import pread_all
 | 
				
			||||||
    crc32 = zlib.crc32
 | 
					 | 
				
			||||||
except ImportError:
 | 
					except ImportError:
 | 
				
			||||||
    zlib = None
 | 
					    pread_all = None  # running from source
 | 
				
			||||||
    crc32 = binascii.crc32
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = [
 | 
					__all__ = [
 | 
				
			||||||
    'ZIP_DEFLATED',
 | 
					    'ZIP_DEFLATED',
 | 
				
			||||||
@ -527,10 +526,11 @@ class ZipExtFile(io.BufferedIOBase):
 | 
				
			|||||||
    # Search for universal newlines or line chunks.
 | 
					    # Search for universal newlines or line chunks.
 | 
				
			||||||
    PATTERN = re.compile(br'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)')
 | 
					    PATTERN = re.compile(br'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, fileobj, mode, zipinfo, decrypter=None):
 | 
					    def __init__(self, fileobj, pos, mode, zipinfo, decrypter=None, pread_fd=-1):
 | 
				
			||||||
        self._fileobj = fileobj
 | 
					        self._fileobj = fileobj
 | 
				
			||||||
 | 
					        self._pread_fd = pread_fd
 | 
				
			||||||
        self._decrypter = decrypter
 | 
					        self._decrypter = decrypter
 | 
				
			||||||
        self._orig_pos = fileobj.tell()
 | 
					        self._orig_pos = pos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._compress_type = zipinfo.compress_type
 | 
					        self._compress_type = zipinfo.compress_type
 | 
				
			||||||
        self._compress_size = zipinfo.compress_size
 | 
					        self._compress_size = zipinfo.compress_size
 | 
				
			||||||
@ -704,16 +704,18 @@ class ZipExtFile(io.BufferedIOBase):
 | 
				
			|||||||
        return data
 | 
					        return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def read_raw(self):
 | 
					    def read_raw(self):
 | 
				
			||||||
        pos = self._fileobj.tell()
 | 
					 | 
				
			||||||
        self._fileobj.seek(self._orig_pos)
 | 
					 | 
				
			||||||
        bytes_to_read = self._compress_size
 | 
					        bytes_to_read = self._compress_size
 | 
				
			||||||
        if self._decrypter is not None:
 | 
					        if self._decrypter is not None:
 | 
				
			||||||
            bytes_to_read -= 12
 | 
					            bytes_to_read -= 12
 | 
				
			||||||
        raw = b''
 | 
					        raw = b''
 | 
				
			||||||
 | 
					 | 
				
			||||||
        if bytes_to_read > 0:
 | 
					        if bytes_to_read > 0:
 | 
				
			||||||
            raw = self._fileobj.read(bytes_to_read)
 | 
					            if self._pread_fd > -1:
 | 
				
			||||||
        self._fileobj.seek(pos)
 | 
					                raw = pread_all(self._pread_fd, bytes_to_read, self._orig_pos)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                pos = self._fileobj.tell()
 | 
				
			||||||
 | 
					                self._fileobj.seek(self._orig_pos)
 | 
				
			||||||
 | 
					                raw = self._fileobj.read(bytes_to_read)
 | 
				
			||||||
 | 
					                self._fileobj.seek(pos)
 | 
				
			||||||
        return raw
 | 
					        return raw
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def decrypt_and_uncompress(self, raw: bytes) -> bytes:
 | 
					    def decrypt_and_uncompress(self, raw: bytes) -> bytes:
 | 
				
			||||||
@ -759,13 +761,7 @@ class ZipFile:
 | 
				
			|||||||
        if mode not in ('r', 'w', 'a'):
 | 
					        if mode not in ('r', 'w', 'a'):
 | 
				
			||||||
            raise RuntimeError(f'ZipFile() requires mode "r", "w", or "a" not {mode}')
 | 
					            raise RuntimeError(f'ZipFile() requires mode "r", "w", or "a" not {mode}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if compression == ZIP_STORED:
 | 
					        if compression not in (ZIP_STORED, ZIP_DEFLATED):
 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
        elif compression == ZIP_DEFLATED:
 | 
					 | 
				
			||||||
            if not zlib:
 | 
					 | 
				
			||||||
                raise RuntimeError(
 | 
					 | 
				
			||||||
                      'Compression requires the (missing) zlib module')
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise RuntimeError(f'The compression method {compression} is not supported')
 | 
					            raise RuntimeError(f'The compression method {compression} is not supported')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._allowZip64 = allowZip64
 | 
					        self._allowZip64 = allowZip64
 | 
				
			||||||
@ -795,6 +791,16 @@ class ZipFile:
 | 
				
			|||||||
            self._filePassed = 1
 | 
					            self._filePassed = 1
 | 
				
			||||||
            self.fp = file
 | 
					            self.fp = file
 | 
				
			||||||
            self.filename = getattr(file, 'name', None)
 | 
					            self.filename = getattr(file, 'name', None)
 | 
				
			||||||
 | 
					        self.pread_fd = -1
 | 
				
			||||||
 | 
					        if pread_all is not None:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                fd = self.fp.fileno()
 | 
				
			||||||
 | 
					            except Exception:
 | 
				
			||||||
 | 
					                fd = -1
 | 
				
			||||||
 | 
					            if fd > -1:
 | 
				
			||||||
 | 
					                with suppress(Exception):
 | 
				
			||||||
 | 
					                    pread_all(fd, 1, 0)
 | 
				
			||||||
 | 
					                    self.pread_fd = fd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if key == 'r':
 | 
					        if key == 'r':
 | 
				
			||||||
            self._GetContents()
 | 
					            self._GetContents()
 | 
				
			||||||
@ -1044,7 +1050,6 @@ class ZipFile:
 | 
				
			|||||||
            raise RuntimeError(
 | 
					            raise RuntimeError(
 | 
				
			||||||
                  'Attempt to read ZIP archive that was already closed')
 | 
					                  'Attempt to read ZIP archive that was already closed')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        zef_file = self.fp
 | 
					 | 
				
			||||||
        # Make sure we have an info object
 | 
					        # Make sure we have an info object
 | 
				
			||||||
        if isinstance(name, ZipInfo):
 | 
					        if isinstance(name, ZipInfo):
 | 
				
			||||||
            # 'name' is already an info object
 | 
					            # 'name' is already an info object
 | 
				
			||||||
@ -1053,18 +1058,31 @@ class ZipFile:
 | 
				
			|||||||
            # Get info object for name
 | 
					            # Get info object for name
 | 
				
			||||||
            zinfo = self.getinfo(name)
 | 
					            zinfo = self.getinfo(name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        zef_file.seek(zinfo.header_offset, os.SEEK_SET)
 | 
					        pos = zinfo.header_offset
 | 
				
			||||||
 | 
					        if self.pread_fd > -1:
 | 
				
			||||||
 | 
					            def read(n):
 | 
				
			||||||
 | 
					                nonlocal pos
 | 
				
			||||||
 | 
					                ans = pread_all(self.pread_fd, n, pos)
 | 
				
			||||||
 | 
					                pos += len(ans)
 | 
				
			||||||
 | 
					                return ans
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.fp.seek(zinfo.header_offset, os.SEEK_SET)
 | 
				
			||||||
 | 
					            def read(n):
 | 
				
			||||||
 | 
					                nonlocal pos
 | 
				
			||||||
 | 
					                ans = self.fp.read(n)
 | 
				
			||||||
 | 
					                pos += len(ans)
 | 
				
			||||||
 | 
					                return ans
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Skip the file header:
 | 
					        # Skip the file header:
 | 
				
			||||||
        fheader = zef_file.read(sizeFileHeader)
 | 
					        fheader = read(sizeFileHeader)
 | 
				
			||||||
        if fheader[0:4] != stringFileHeader:
 | 
					        if fheader[0:4] != stringFileHeader:
 | 
				
			||||||
            raise BadZipfile('Bad magic number for file header')
 | 
					            raise BadZipfile('Bad magic number for file header')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        fheader = struct.unpack(structFileHeader, fheader)
 | 
					        fheader = struct.unpack(structFileHeader, fheader)
 | 
				
			||||||
        fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
 | 
					        fname = read(fheader[_FH_FILENAME_LENGTH])
 | 
				
			||||||
        fname = decode_zip_internal_file_name(fname, zinfo.flag_bits)
 | 
					        fname = decode_zip_internal_file_name(fname, zinfo.flag_bits)
 | 
				
			||||||
        if fheader[_FH_EXTRA_FIELD_LENGTH]:
 | 
					        if fheader[_FH_EXTRA_FIELD_LENGTH]:
 | 
				
			||||||
            zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
 | 
					            read(fheader[_FH_EXTRA_FIELD_LENGTH])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if fname != zinfo.orig_filename:
 | 
					        if fname != zinfo.orig_filename:
 | 
				
			||||||
            print(f'WARNING: Header ({fname!r}) and directory ({zinfo.orig_filename!r}) filenames do not'
 | 
					            print(f'WARNING: Header ({fname!r}) and directory ({zinfo.orig_filename!r}) filenames do not'
 | 
				
			||||||
@ -1090,7 +1108,7 @@ class ZipFile:
 | 
				
			|||||||
            # completely random, while the 12th contains the MSB of the CRC,
 | 
					            # completely random, while the 12th contains the MSB of the CRC,
 | 
				
			||||||
            # or the MSB of the file time depending on the header type
 | 
					            # or the MSB of the file time depending on the header type
 | 
				
			||||||
            # and is used to check the correctness of the password.
 | 
					            # and is used to check the correctness of the password.
 | 
				
			||||||
            byts = zef_file.read(12)
 | 
					            byts = read(12)
 | 
				
			||||||
            h = list(map(zd, bytearray(byts[0:12])))
 | 
					            h = list(map(zd, bytearray(byts[0:12])))
 | 
				
			||||||
            if zinfo.flag_bits & 0x8:
 | 
					            if zinfo.flag_bits & 0x8:
 | 
				
			||||||
                # compare against the file type from extended local headers
 | 
					                # compare against the file type from extended local headers
 | 
				
			||||||
@ -1101,7 +1119,7 @@ class ZipFile:
 | 
				
			|||||||
            if h[11] != check_byte:
 | 
					            if h[11] != check_byte:
 | 
				
			||||||
                raise RuntimeError('Bad password for file', name)
 | 
					                raise RuntimeError('Bad password for file', name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return ZipExtFile(zef_file, mode, zinfo, zd)
 | 
					        return ZipExtFile(self.fp, pos, mode, zinfo, zd, self.pread_fd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def extract(self, member, path=None, pwd=None):
 | 
					    def extract(self, member, path=None, pwd=None):
 | 
				
			||||||
        '''Extract a member from the archive to the current working directory,
 | 
					        '''Extract a member from the archive to the current working directory,
 | 
				
			||||||
@ -1138,7 +1156,7 @@ class ZipFile:
 | 
				
			|||||||
        lock = Lock()
 | 
					        lock = Lock()
 | 
				
			||||||
        def do_one(a):
 | 
					        def do_one(a):
 | 
				
			||||||
            return self._extract_member_to(*a, lock=lock)
 | 
					            return self._extract_member_to(*a, lock=lock)
 | 
				
			||||||
        with ThreadPoolExecutor(thread_name_prefix='ZipFile-') as e:
 | 
					        with ThreadPoolExecutor(max_workers=12, thread_name_prefix='ZipFile-') as e:
 | 
				
			||||||
            tuple(e.map(do_one, args))
 | 
					            tuple(e.map(do_one, args))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_targetpath(self, member: ZipInfo, targetpath: str) -> str:
 | 
					    def _get_targetpath(self, member: ZipInfo, targetpath: str) -> str:
 | 
				
			||||||
@ -1202,12 +1220,16 @@ class ZipFile:
 | 
				
			|||||||
            target = open(targetpath, 'wb')
 | 
					            target = open(targetpath, 'wb')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with target:
 | 
					        with target:
 | 
				
			||||||
            if lock is None:
 | 
					            if max(member.compress_size, member.file_size) > 256*1024*1024:
 | 
				
			||||||
                with closing(self.open(member, pwd=pwd)) as source:
 | 
					                with lock, closing(self.open(member, pwd=pwd)) as source:
 | 
				
			||||||
                    shutil.copyfileobj(source, target)
 | 
					                    shutil.copyfileobj(source, target)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                with lock, closing(self.open(member, pwd=pwd)) as source:
 | 
					                if self.pread_fd > -1:
 | 
				
			||||||
                    src = source.read_raw()
 | 
					                    with closing(self.open(member, pwd=pwd)) as source:
 | 
				
			||||||
 | 
					                        src = source.read_raw()
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    with lock, closing(self.open(member, pwd=pwd)) as source:
 | 
				
			||||||
 | 
					                        src = source.read_raw()
 | 
				
			||||||
                src = source.decrypt_and_uncompress(src)
 | 
					                src = source.decrypt_and_uncompress(src)
 | 
				
			||||||
                source.check_crc(src)
 | 
					                source.check_crc(src)
 | 
				
			||||||
                target.write(src)
 | 
					                target.write(src)
 | 
				
			||||||
@ -1703,10 +1725,7 @@ def main(args=None):
 | 
				
			|||||||
            print(USAGE)
 | 
					            print(USAGE)
 | 
				
			||||||
            sys.exit(1)
 | 
					            sys.exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        zf = ZipFile(args[1], 'r')
 | 
					        extractall(args[1], args[2])
 | 
				
			||||||
        out = args[2]
 | 
					 | 
				
			||||||
        zf.extractall(out)
 | 
					 | 
				
			||||||
        zf.close()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    elif args[0] == '-c':
 | 
					    elif args[0] == '-c':
 | 
				
			||||||
        if len(args) < 3:
 | 
					        if len(args) < 3:
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user