This commit is contained in:
Kovid Goyal 2019-03-31 19:40:00 +05:30
parent 4a3c9ca32f
commit 65146815d7
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
from __future__ import (unicode_literals, division, absolute_import, print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
@ -22,18 +21,22 @@ FOOTER_MAGIC = b'YZ'
DELTA_FILTER_ID = 0x03
LZMA2_FILTER_ID = 0x21
def align(raw):
extra = len(raw) % 4
if extra:
raw += b'\0' * (4 - extra)
return raw
def as_bytes(*args):
return bytes(bytearray(args))
def crc32(raw, start=0):
return 0xFFFFFFFF & _crc32(raw, start)
def decode_var_int(f):
ans, i, ch = 0, -1, 0x80
while ch >= 0x80:
@ -44,6 +47,7 @@ def decode_var_int(f):
ans |= (ch & 0x7f) << (i * 7)
return ans
def decode_var_int2(raw, pos):
ans, ch, opos = 0, 0x80, pos
while ch >= 0x80:
@ -54,6 +58,7 @@ def decode_var_int2(raw, pos):
pos += 1
return ans, pos
def encode_var_int(num):
if num == 0:
return b'\0'
@ -65,6 +70,7 @@ def encode_var_int(num):
buf[-1] &= 0x7F
return bytes(buf)
def read_stream_header(f):
try:
magic, stream_flags1, stream_flags2, crc = unpack(b'<6s2BI', f.read(12))
@ -81,6 +87,7 @@ def read_stream_header(f):
raise InvalidXZ('Stream flags header CRC incorrect')
return check_type
class CRCChecker(object):
def __init__(self, check_type):
@ -108,6 +115,7 @@ class CRCChecker(object):
def check(self, raw):
return self.code == unpack(self.fmt, raw)[0]
class Sha256Checker(object):
def __init__(self, *args):
@ -126,6 +134,7 @@ class Sha256Checker(object):
def check(self, raw):
return self.code == raw
class DummyChecker(object):
size = 0
@ -140,6 +149,7 @@ class DummyChecker(object):
def finish(self):
pass
class LZMA2Filter(object):
BUFSIZE = 10 # MB
@ -167,6 +177,7 @@ class LZMA2Filter(object):
def __call__(self, f, outfile, filters=()):
w = outfile.write
c = self.crc
def write(raw):
if filters:
raw = bytearray(raw)
@ -174,12 +185,16 @@ class LZMA2Filter(object):
raw = flt(raw)
raw = bytes(raw)
w(raw), c(raw)
try:
lzma.decompress2(f.read, f.seek, write, self.props, self.bufsize)
except lzma.error as e:
raise InvalidXZ('Failed to decode LZMA2 block with error code: %s' % e.message)
raise InvalidXZ(
'Failed to decode LZMA2 block with error code: %s' % e.message
)
self.crc.finish()
class DeltaFilter(object):
def __init__(self, props, *args):
@ -193,12 +208,15 @@ class DeltaFilter(object):
self.pos = lzma.delta_decode(raw, self.history, self.pos, self.distance)
return raw
def test_delta_filter():
raw = b'\xA1\xB1\x01\x02\x01\x02\x01\x02'
draw = b'\xA1\xB1\xA2\xB3\xA3\xB5\xA4\xB7'
def eq(s, d):
if s != d:
raise ValueError('%r != %r' % (s, d))
eq(draw, bytes(DeltaFilter(b'\x01')(bytearray(raw))))
f = DeltaFilter(b'\x01')
for ch, dch in zip(raw, draw):
@ -207,11 +225,14 @@ def test_delta_filter():
Block = namedtuple('Block', 'unpadded_size uncompressed_size')
def read_block_header(f, block_header_size_, check_type):
block_header_size = 4 * (ord(block_header_size_) + 1)
if block_header_size < 8:
raise InvalidXZ('Invalid block header size: %d' % block_header_size)
header, crc = unpack(b'<%dsI' % (block_header_size - 5), f.read(block_header_size - 1))
header, crc = unpack(
b'<%dsI' % (block_header_size - 5), f.read(block_header_size - 1)
)
if crc != crc32(block_header_size_ + header):
raise InvalidXZ('Block header CRC mismatch')
block_flags = ord(header[0])
@ -237,7 +258,7 @@ def read_block_header(f, block_header_size_, check_type):
raise InvalidXZ('Invalid filter id: %d' % filter_id)
if filter_id not in (LZMA2_FILTER_ID, DELTA_FILTER_ID):
raise InvalidXZ('Unsupported filter ID: 0x%x' % filter_id)
props = header[pos:pos+size_of_properties]
props = header[pos:pos + size_of_properties]
pos += size_of_properties
if len(props) != size_of_properties:
raise InvalidXZ('Incomplete filter properties')
@ -245,16 +266,22 @@ def read_block_header(f, block_header_size_, check_type):
raise InvalidXZ('LZMA2 filter must be the last filter')
elif filter_id == DELTA_FILTER_ID and not number_of_filters:
raise InvalidXZ('Delta filter cannot be the last filter')
filters.append((LZMA2Filter if filter_id == LZMA2_FILTER_ID else DeltaFilter)(props, check_type))
filters.append(
(LZMA2Filter
if filter_id == LZMA2_FILTER_ID else DeltaFilter)(props, check_type)
)
padding = header[pos:]
if padding.lstrip(b'\0'):
raise InvalidXZ('Non-null block header padding: %r' % padding)
filters.reverse()
return filters, compressed_size, uncompressed_size
def read_block(f, block_header_size_, check_type, outfile):
start_pos = f.tell() - 1
filters, compressed_size, uncompressed_size = read_block_header(f, block_header_size_, check_type)
filters, compressed_size, uncompressed_size = read_block_header(
f, block_header_size_, check_type
)
fpos, opos = f.tell(), outfile.tell()
filters[0](f, outfile, filters[1:])
actual_compressed_size = f.tell() - fpos
@ -277,6 +304,7 @@ def read_block(f, block_header_size_, check_type, outfile):
raise InvalidXZ('CRC for data does not match')
return Block(f.tell() - padding_count - start_pos, uncompressed_actual_size)
def read_index(f):
pos = f.tell() - 1
number_of_records = decode_var_int(f)
@ -298,6 +326,7 @@ def read_index(f):
if crc != crc32(raw):
raise InvalidXZ('Index field CRC mismatch')
def read_stream_footer(f, check_type, index_size):
crc, = unpack(b'<I', f.read(4))
raw = f.read(6)
@ -312,6 +341,7 @@ def read_stream_footer(f, check_type, index_size):
if crc != crc32(raw):
raise InvalidXZ('Stream footer CRC mismatch')
def read_stream(f, outfile):
check_type = read_stream_header(f)
blocks, index = [], None
@ -329,6 +359,7 @@ def read_stream(f, outfile):
raise InvalidXZ('Index does not match actual blocks in file')
read_stream_footer(f, check_type, index_size)
def decompress(raw, outfile=None):
'''
Decompress the specified data.
@ -359,6 +390,7 @@ def decompress(raw, outfile=None):
raise InvalidXZ('Found trailing garbage between streams')
return outfile
def compress(raw, outfile=None, level=5, check_type='crc64'):
'''
Compress the specified data into a .xz stream (which can be written directly as
@ -379,13 +411,22 @@ def compress(raw, outfile=None, level=5, check_type='crc64'):
# Write stream header
outfile.write(HEADER_MAGIC)
check_type = {'crc':1, 'crc32':1, 'sha256':0xa, None:0, '':0, 'none':0, 'None':0}.get(check_type, 4)
check_type = {
'crc': 1,
'crc32': 1,
'sha256': 0xa,
None: 0,
'': 0,
'none': 0,
'None': 0
}.get(check_type, 4)
stream_flags = as_bytes(0, check_type)
outfile.write(stream_flags)
outfile.write(pack(b'<I', crc32(stream_flags)))
# Write block header
filter_flags = encode_var_int(LZMA2_FILTER_ID) + encode_var_int(1) + lzma.preset_map[level]
filter_flags = encode_var_int(LZMA2_FILTER_ID
) + encode_var_int(1) + lzma.preset_map[level]
block_header = align(b'\0\0' + filter_flags)
bhs = ((4 + len(block_header)) // 4) - 1
block_header = as_bytes(bhs) + block_header[1:]
@ -394,14 +435,21 @@ def compress(raw, outfile=None, level=5, check_type='crc64'):
outfile.write(block_header)
# Write compressed data and check
checker = {0:DummyChecker, 1:CRCChecker, 4:CRCChecker, 0xa:Sha256Checker}[check_type](check_type)
checker = {
0: DummyChecker,
1: CRCChecker,
4: CRCChecker,
0xa: Sha256Checker
}[check_type](check_type)
uncompressed_size = [0]
def read(n):
ans = raw.read(n)
if ans:
uncompressed_size[0] += len(ans)
checker(ans)
return ans
lzma.compress(read, outfile.write, None, level)
unpadded_size = outfile.tell() - start
pos = outfile.tell()
@ -423,7 +471,9 @@ def compress(raw, outfile=None, level=5, check_type='crc64'):
# Write stream footer
backwards_size = pack(b'<I', ((len(index) + 4) // 4) - 1)
outfile.write(pack(b'<I', crc32(backwards_size + stream_flags)))
outfile.write(backwards_size), outfile.write(stream_flags), outfile.write(FOOTER_MAGIC)
outfile.write(backwards_size), outfile.write(stream_flags
), outfile.write(FOOTER_MAGIC)
def test_lzma2():
raw = P('template-functions.json', allow_user_override=False, data=True)
@ -435,6 +485,7 @@ def test_lzma2():
if obuf.getvalue() != raw:
raise ValueError('Roundtripping via LZMA2 failed')
def test_xz():
raw = P('template-functions.json', allow_user_override=False, data=True)
ibuf, obuf = BytesIO(raw), BytesIO()
@ -445,6 +496,7 @@ def test_xz():
if obuf.getvalue() != raw:
raise ValueError('Roundtripping via XZ failed')
if __name__ == '__main__':
import sys
decompress(open(sys.argv[-1], 'rb'))