More zipfile fixes

This commit is contained in:
Kovid Goyal 2011-03-23 11:16:44 -06:00
parent 24663f7853
commit a07fcf2fea

View File

@ -2,7 +2,7 @@
Read and write ZIP files. Modified by Kovid Goyal to support replacing files in Read and write ZIP files. Modified by Kovid Goyal to support replacing files in
a zip archive, detecting filename encoding, updating zip files, etc. a zip archive, detecting filename encoding, updating zip files, etc.
""" """
import struct, os, time, sys, shutil, stat import struct, os, time, sys, shutil, stat, re, io
import binascii, cStringIO import binascii, cStringIO
from contextlib import closing from contextlib import closing
from tempfile import SpooledTemporaryFile from tempfile import SpooledTemporaryFile
@ -477,215 +477,209 @@ class _ZipDecrypter:
self._UpdateKeys(c) self._UpdateKeys(c)
return c return c
class ZipExtFile(io.BufferedIOBase):
    """File-like object for reading an archive member.
       Is returned by ZipFile.open().
    """

    # Max size supported by decompressor.
    # NOTE(review): '-' binds tighter than '<<', so this is 1 << 30, not
    # (1 << 31) - 1.  Kept as-is to match upstream CPython behaviour.
    MAX_N = 1 << 31 - 1

    # Read from compressed files in 4k blocks.
    MIN_READ_SIZE = 4096

    # Search for universal newlines or line chunks.
    PATTERN = re.compile(r'^(?P<chunk>[^\r\n]+)|(?P<newline>\n|\r\n?)')

    def __init__(self, fileobj, mode, zipinfo, decrypter=None):
        """Wrap *fileobj*, positioned at the start of the member's data.

        fileobj   -- underlying archive file object
        mode      -- mode string passed to ZipFile.open(); 'U' enables
                     universal-newline translation in readline()
        zipinfo   -- ZipInfo record describing the member
        decrypter -- optional decrypter callable for encrypted members
        """
        self._fileobj = fileobj
        self._decrypter = decrypter
        self._orig_pos = fileobj.tell()

        self._compress_type = zipinfo.compress_type
        self._compress_size = zipinfo.compress_size
        self._compress_left = zipinfo.compress_size

        if self._compress_type == ZIP_DEFLATED:
            # Raw deflate stream: negative wbits means no zlib header/trailer.
            self._decompressor = zlib.decompressobj(-15)
        self._unconsumed = ''

        self._readbuffer = ''
        self._offset = 0

        self._universal = 'U' in mode
        self.newlines = None

        # Adjust read size for encrypted files since the first 12 bytes
        # are for the encryption/password information.
        if self._decrypter is not None:
            self._compress_left -= 12

        self.mode = mode
        self.name = zipinfo.filename

        if hasattr(zipinfo, 'CRC'):
            self._expected_crc = zipinfo.CRC
            self._running_crc = crc32(b'') & 0xffffffff
        else:
            # No reference CRC available; integrity checking is skipped.
            self._expected_crc = None

    def readline(self, limit=-1):
        """Read and return a line from the stream.

        If limit is specified, at most limit bytes will be read.
        """

        if not self._universal and limit < 0:
            # Shortcut common case - newline found in buffer.
            i = self._readbuffer.find('\n', self._offset) + 1
            if i > 0:
                line = self._readbuffer[self._offset: i]
                self._offset = i
                return line

        if not self._universal:
            return io.BufferedIOBase.readline(self, limit)

        line = ''
        while limit < 0 or len(line) < limit:
            readahead = self.peek(2)
            if readahead == '':
                return line

            #
            # Search for universal newlines or line chunks.
            #
            # The pattern returns either a line chunk or a newline, but not
            # both. Combined with peek(2), we are assured that the sequence
            # '\r\n' is always retrieved completely and never split into
            # separate newlines - '\r', '\n' due to coincidental readaheads.
            #
            match = self.PATTERN.search(readahead)
            newline = match.group('newline')
            if newline is not None:
                if self.newlines is None:
                    self.newlines = []
                if newline not in self.newlines:
                    self.newlines.append(newline)
                self._offset += len(newline)
                return line + '\n'

            chunk = match.group('chunk')
            if limit >= 0:
                chunk = chunk[: limit - len(line)]

            self._offset += len(chunk)
            line += chunk

        return line

    def peek(self, n=1):
        """Returns buffered bytes without advancing the position."""
        if n > len(self._readbuffer) - self._offset:
            chunk = self.read(n)
            # Rewind the buffer offset so the data is re-readable.
            self._offset -= len(chunk)

        # Return up to 512 bytes to reduce allocation overhead for tight loops.
        return self._readbuffer[self._offset: self._offset + 512]

    def readable(self):
        # This stream only supports reading.
        return True

    def read(self, n=-1):
        """Read and return up to n bytes.

        If the argument is omitted, None, or negative, data is read and
        returned until EOF is reached.
        """
        buf = ''
        if n is None:
            n = -1
        while True:
            if n < 0:
                data = self.read1(n)
            elif n > len(buf):
                data = self.read1(n - len(buf))
            else:
                return buf
            if len(data) == 0:
                return buf
            buf += data

    def _update_crc(self, newdata, eof):
        # Update the CRC using the given data.
        if self._expected_crc is None:
            # No need to compute the CRC if we don't have a reference value
            return
        self._running_crc = crc32(newdata, self._running_crc) & 0xffffffff
        # Check the CRC if we're at the end of the file
        if eof and self._running_crc != self._expected_crc:
            raise BadZipfile("Bad CRC-32 for file %r" % self.name)

    def read1(self, n):
        """Read up to n bytes with at most one read() system call."""

        # Simplify algorithm (branching) by transforming negative n to large n.
        # Check for None *first*: 'None < 0' only works by accident on
        # Python 2 and raises TypeError on Python 3.
        if n is None or n < 0:
            n = self.MAX_N

        # Bytes available in read buffer.
        len_readbuffer = len(self._readbuffer) - self._offset

        # Read from file.
        if self._compress_left > 0 and n > len_readbuffer + len(self._unconsumed):
            nbytes = n - len_readbuffer - len(self._unconsumed)
            nbytes = max(nbytes, self.MIN_READ_SIZE)
            nbytes = min(nbytes, self._compress_left)

            data = self._fileobj.read(nbytes)
            self._compress_left -= len(data)

            if data and self._decrypter is not None:
                data = ''.join(map(self._decrypter, data))

            if self._compress_type == ZIP_STORED:
                self._update_crc(data, eof=(self._compress_left == 0))
                self._readbuffer = self._readbuffer[self._offset:] + data
                self._offset = 0
            else:
                # Prepare deflated bytes for decompression.
                self._unconsumed += data

        # Handle unconsumed data.
        if (len(self._unconsumed) > 0 and n > len_readbuffer and
            self._compress_type == ZIP_DEFLATED):
            data = self._decompressor.decompress(
                self._unconsumed,
                max(n - len_readbuffer, self.MIN_READ_SIZE)
            )

            self._unconsumed = self._decompressor.unconsumed_tail
            eof = len(self._unconsumed) == 0 and self._compress_left == 0
            if eof:
                data += self._decompressor.flush()

            self._update_crc(data, eof=eof)
            self._readbuffer = self._readbuffer[self._offset:] + data
            self._offset = 0

        # Read from buffer.
        data = self._readbuffer[self._offset: self._offset + n]
        self._offset += len(data)
        return data

    def read_raw(self):
        """Return the member's raw (still compressed/undecrypted) bytes.

        The underlying file position is saved and restored, so this does
        not disturb normal read() progress.
        """
        pos = self._fileobj.tell()
        self._fileobj.seek(self._orig_pos)
        bytes_to_read = self._compress_size
        # The first 12 bytes of an encrypted member are the encryption
        # header, not file data.
        if self._decrypter is not None:
            bytes_to_read -= 12
        raw = b''
        if bytes_to_read > 0:
            raw = self._fileobj.read(bytes_to_read)
        self._fileobj.seek(pos)
        return raw
class ZipFile: class ZipFile:
""" Class with methods to open, read, write, close, list and update zip files. """ Class with methods to open, read, write, close, list and update zip files.
@ -1053,16 +1047,7 @@ class ZipFile:
if ord(h[11]) != check_byte: if ord(h[11]) != check_byte:
raise RuntimeError("Bad password for file", name) raise RuntimeError("Bad password for file", name)
# build and return a ZipExtFile return ZipExtFile(zef_file, mode, zinfo, zd)
if zd is None:
zef = ZipExtFile(zef_file, zinfo)
else:
zef = ZipExtFile(zef_file, zinfo, zd)
# set universal newlines on ZipExtFile if necessary
if "U" in mode:
zef.set_univ_newlines(True)
return zef
def extract(self, member, path=None, pwd=None): def extract(self, member, path=None, pwd=None):
"""Extract a member from the archive to the current working directory, """Extract a member from the archive to the current working directory,
@ -1087,6 +1072,10 @@ class ZipFile:
if members is None: if members is None:
members = self.namelist() members = self.namelist()
# Kovid: Extract longer names first, just in case the zip file has
# an entry for a directory without a trailing slash
members.sort(key=len, reverse=True)
for zipinfo in members: for zipinfo in members:
self.extract(zipinfo, path, pwd) self.extract(zipinfo, path, pwd)
@ -1102,10 +1091,10 @@ class ZipFile:
targetpath = targetpath[:-1] targetpath = targetpath[:-1]
# don't include leading "/" from file name if present # don't include leading "/" from file name if present
fname = member.filename if member.filename[0] == '/':
if fname.startswith('/'): targetpath = os.path.join(targetpath, member.filename[1:])
fname = fname[1:] else:
targetpath = os.path.join(targetpath, fname) targetpath = os.path.join(targetpath, member.filename)
targetpath = os.path.normpath(targetpath) targetpath = os.path.normpath(targetpath)
@ -1114,7 +1103,15 @@ class ZipFile:
if upperdirs and not os.path.exists(upperdirs): if upperdirs and not os.path.exists(upperdirs):
os.makedirs(upperdirs) os.makedirs(upperdirs)
if not os.path.exists(targetpath): # Could be a previously automatically created directory if member.filename[-1] == '/':
if not os.path.isdir(targetpath):
os.mkdir(targetpath)
self.extract_mapping[member.filename] = targetpath
return targetpath
if not os.path.exists(targetpath):
# Kovid: Could be a previously automatically created directory
# in which case it is ignored
with closing(self.open(member, pwd=pwd)) as source: with closing(self.open(member, pwd=pwd)) as source:
try: try:
with open(targetpath, 'wb') as target: with open(targetpath, 'wb') as target: