Restore LitReader refactoring (again)

This commit is contained in:
Marshall T. Vandegrift 2009-01-17 11:18:14 -05:00
parent 7449870919
commit bd296fa43c

View File

@ -7,20 +7,24 @@ __license__ = 'GPL v3'
__copyright__ = '2008, Kovid Goyal <kovid at kovidgoyal.net> ' \ __copyright__ = '2008, Kovid Goyal <kovid at kovidgoyal.net> ' \
'and Marshall T. Vandegrift <llasram@gmail.com>' 'and Marshall T. Vandegrift <llasram@gmail.com>'
import sys, struct, cStringIO, os import sys, struct, os
import functools import functools
import re import re
from urlparse import urldefrag from urlparse import urldefrag
from cStringIO import StringIO
from urllib import unquote as urlunquote
from lxml import etree from lxml import etree
from calibre.ebooks.lit import LitError from calibre.ebooks.lit import LitError
from calibre.ebooks.lit.maps import OPF_MAP, HTML_MAP from calibre.ebooks.lit.maps import OPF_MAP, HTML_MAP
import calibre.ebooks.lit.mssha1 as mssha1 import calibre.ebooks.lit.mssha1 as mssha1
from calibre.ebooks.oeb.base import urlnormalize from calibre.ebooks.oeb.base import XML_PARSER, urlnormalize
from calibre.ebooks import DRMError from calibre.ebooks import DRMError
from calibre import plugins from calibre import plugins
lzx, lxzerror = plugins['lzx'] lzx, lxzerror = plugins['lzx']
msdes, msdeserror = plugins['msdes'] msdes, msdeserror = plugins['msdes']
__all__ = ["LitReader"]
XML_DECL = """<?xml version="1.0" encoding="UTF-8" ?> XML_DECL = """<?xml version="1.0" encoding="UTF-8" ?>
""" """
OPF_DECL = """<?xml version="1.0" encoding="UTF-8" ?> OPF_DECL = """<?xml version="1.0" encoding="UTF-8" ?>
@ -108,6 +112,9 @@ def consume_sized_utf8_string(bytes, zpad=False):
pos += 1 pos += 1
return u''.join(result), bytes[pos:] return u''.join(result), bytes[pos:]
def encode(string):
return unicode(string).encode('ascii', 'xmlcharrefreplace')
class UnBinary(object): class UnBinary(object):
AMPERSAND_RE = re.compile( AMPERSAND_RE = re.compile(
r'&(?!(?:#[0-9]+|#x[0-9a-fA-F]+|[a-zA-Z_:][a-zA-Z0-9.-_:]+);)') r'&(?!(?:#[0-9]+|#x[0-9a-fA-F]+|[a-zA-Z_:][a-zA-Z0-9.-_:]+);)')
@ -118,13 +125,13 @@ class UnBinary(object):
def __init__(self, bin, path, manifest={}, map=HTML_MAP): def __init__(self, bin, path, manifest={}, map=HTML_MAP):
self.manifest = manifest self.manifest = manifest
self.tag_map, self.attr_map, self.tag_to_attr_map = map self.tag_map, self.attr_map, self.tag_to_attr_map = map
self.opf = map is OPF_MAP self.is_html = map is HTML_MAP
self.bin = bin
self.dir = os.path.dirname(path) self.dir = os.path.dirname(path)
self.buf = cStringIO.StringIO() buf = StringIO()
self.binary_to_text() self.binary_to_text(bin, buf)
self.raw = self.buf.getvalue().lstrip().decode('utf-8') self.raw = buf.getvalue().lstrip()
self.escape_reserved() self.escape_reserved()
self._tree = None
def escape_reserved(self): def escape_reserved(self):
raw = self.raw raw = self.raw
@ -151,18 +158,28 @@ class UnBinary(object):
return '/'.join(relpath) return '/'.join(relpath)
def __unicode__(self): def __unicode__(self):
return self.raw.decode('utf-8')
def __str__(self):
return self.raw return self.raw
def tree():
def fget(self):
if not self._tree:
self._tree = etree.fromstring(self.raw, parser=XML_PARSER)
return self._tree
return property(fget=fget)
tree = tree()
def binary_to_text(self, base=0, depth=0): def binary_to_text(self, bin, buf, index=0, depth=0):
tag_name = current_map = None tag_name = current_map = None
dynamic_tag = errors = 0 dynamic_tag = errors = 0
in_censorship = is_goingdown = False in_censorship = is_goingdown = False
state = 'text' state = 'text'
index = base
flags = 0 flags = 0
while index < len(self.bin): while index < len(bin):
c, index = read_utf8_char(self.bin, index) c, index = read_utf8_char(bin, index)
oc = ord(c) oc = ord(c)
if state == 'text': if state == 'text':
@ -175,7 +192,7 @@ class UnBinary(object):
c = '>>' c = '>>'
elif c == '<': elif c == '<':
c = '<<' c = '<<'
self.buf.write(c.encode('ascii', 'xmlcharrefreplace')) buf.write(encode(c))
elif state == 'get flags': elif state == 'get flags':
if oc == 0: if oc == 0:
@ -188,7 +205,7 @@ class UnBinary(object):
state = 'text' if oc == 0 else 'get attr' state = 'text' if oc == 0 else 'get attr'
if flags & FLAG_OPENING: if flags & FLAG_OPENING:
tag = oc tag = oc
self.buf.write('<') buf.write('<')
if not (flags & FLAG_CLOSING): if not (flags & FLAG_CLOSING):
is_goingdown = True is_goingdown = True
if tag == 0x8000: if tag == 0x8000:
@ -205,7 +222,7 @@ class UnBinary(object):
tag_name = '?'+unichr(tag)+'?' tag_name = '?'+unichr(tag)+'?'
current_map = self.tag_to_attr_map[tag] current_map = self.tag_to_attr_map[tag]
print 'WARNING: tag %s unknown' % unichr(tag) print 'WARNING: tag %s unknown' % unichr(tag)
self.buf.write(unicode(tag_name).encode('utf-8')) buf.write(encode(tag_name))
elif flags & FLAG_CLOSING: elif flags & FLAG_CLOSING:
if depth == 0: if depth == 0:
raise LitError('Extra closing tag') raise LitError('Extra closing tag')
@ -217,15 +234,14 @@ class UnBinary(object):
if not is_goingdown: if not is_goingdown:
tag_name = None tag_name = None
dynamic_tag = 0 dynamic_tag = 0
self.buf.write(' />') buf.write(' />')
else: else:
self.buf.write('>') buf.write('>')
index = self.binary_to_text(base=index, depth=depth+1) index = self.binary_to_text(bin, buf, index, depth+1)
is_goingdown = False is_goingdown = False
if not tag_name: if not tag_name:
raise LitError('Tag ends before it begins.') raise LitError('Tag ends before it begins.')
self.buf.write(u''.join( buf.write(encode(u''.join(('</', tag_name, '>'))))
('</', tag_name, '>')).encode('utf-8'))
dynamic_tag = 0 dynamic_tag = 0
tag_name = None tag_name = None
state = 'text' state = 'text'
@ -245,7 +261,7 @@ class UnBinary(object):
in_censorship = True in_censorship = True
state = 'get value length' state = 'get value length'
continue continue
self.buf.write(' ' + unicode(attr).encode('utf-8') + '=') buf.write(' ' + encode(attr) + '=')
if attr in ['href', 'src']: if attr in ['href', 'src']:
state = 'get href length' state = 'get href length'
else: else:
@ -253,40 +269,39 @@ class UnBinary(object):
elif state == 'get value length': elif state == 'get value length':
if not in_censorship: if not in_censorship:
self.buf.write('"') buf.write('"')
count = oc - 1 count = oc - 1
if count == 0: if count == 0:
if not in_censorship: if not in_censorship:
self.buf.write('"') buf.write('"')
in_censorship = False in_censorship = False
state = 'get attr' state = 'get attr'
continue continue
state = 'get value' state = 'get value'
if oc == 0xffff: if oc == 0xffff:
continue continue
if count < 0 or count > (len(self.bin) - index): if count < 0 or count > (len(bin) - index):
raise LitError('Invalid character count %d' % count) raise LitError('Invalid character count %d' % count)
elif state == 'get value': elif state == 'get value':
if count == 0xfffe: if count == 0xfffe:
if not in_censorship: if not in_censorship:
self.buf.write('%s"' % (oc - 1)) buf.write('%s"' % (oc - 1))
in_censorship = False in_censorship = False
state = 'get attr' state = 'get attr'
elif count > 0: elif count > 0:
if not in_censorship: if not in_censorship:
self.buf.write(c.encode( buf.write(encode(c))
'ascii', 'xmlcharrefreplace'))
count -= 1 count -= 1
if count == 0: if count == 0:
if not in_censorship: if not in_censorship:
self.buf.write('"') buf.write('"')
in_censorship = False in_censorship = False
state = 'get attr' state = 'get attr'
elif state == 'get custom length': elif state == 'get custom length':
count = oc - 1 count = oc - 1
if count <= 0 or count > len(self.bin)-index: if count <= 0 or count > len(bin)-index:
raise LitError('Invalid character count %d' % count) raise LitError('Invalid character count %d' % count)
dynamic_tag += 1 dynamic_tag += 1
state = 'get custom' state = 'get custom'
@ -296,26 +311,26 @@ class UnBinary(object):
tag_name += c tag_name += c
count -= 1 count -= 1
if count == 0: if count == 0:
self.buf.write(unicode(tag_name).encode('utf-8')) buf.write(encode(tag_name))
state = 'get attr' state = 'get attr'
elif state == 'get attr length': elif state == 'get attr length':
count = oc - 1 count = oc - 1
if count <= 0 or count > (len(self.bin) - index): if count <= 0 or count > (len(bin) - index):
raise LitError('Invalid character count %d' % count) raise LitError('Invalid character count %d' % count)
self.buf.write(' ') buf.write(' ')
state = 'get custom attr' state = 'get custom attr'
elif state == 'get custom attr': elif state == 'get custom attr':
self.buf.write(unicode(c).encode('utf-8')) buf.write(encode(c))
count -= 1 count -= 1
if count == 0: if count == 0:
self.buf.write('=') buf.write('=')
state = 'get value length' state = 'get value length'
elif state == 'get href length': elif state == 'get href length':
count = oc - 1 count = oc - 1
if count <= 0 or count > (len(self.bin) - index): if count <= 0 or count > (len(bin) - index):
raise LitError('Invalid character count %d' % count) raise LitError('Invalid character count %d' % count)
href = '' href = ''
state = 'get href' state = 'get href'
@ -329,10 +344,11 @@ class UnBinary(object):
if frag: if frag:
path = '#'.join((path, frag)) path = '#'.join((path, frag))
path = urlnormalize(path) path = urlnormalize(path)
self.buf.write((u'"%s"' % path).encode('utf-8')) buf.write(encode(u'"%s"' % path))
state = 'get attr' state = 'get attr'
return index return index
class DirectoryEntry(object): class DirectoryEntry(object):
def __init__(self, name, section, offset, size): def __init__(self, name, section, offset, size):
self.name = name self.name = name
@ -347,6 +363,7 @@ class DirectoryEntry(object):
def __str__(self): def __str__(self):
return repr(self) return repr(self)
class ManifestItem(object): class ManifestItem(object):
def __init__(self, original, internal, mime_type, offset, root, state): def __init__(self, original, internal, mime_type, offset, root, state):
self.original = original self.original = original
@ -374,65 +391,87 @@ class ManifestItem(object):
% (self.internal, self.path, self.mime_type, self.offset, % (self.internal, self.path, self.mime_type, self.offset,
self.root, self.state) self.root, self.state)
def preserve(function): def preserve(function):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
opos = self._stream.tell() opos = self.stream.tell()
try: try:
return function(self, *args, **kwargs) return function(self, *args, **kwargs)
finally: finally:
self._stream.seek(opos) self.stream.seek(opos)
functools.update_wrapper(wrapper, function) functools.update_wrapper(wrapper, function)
return wrapper return wrapper
class LitReader(object): class LitFile(object):
PIECE_SIZE = 16 PIECE_SIZE = 16
XML_PARSER = etree.XMLParser(
recover=True, resolve_entities=False) def __init__(self, filename_or_stream):
if hasattr(filename_or_stream, 'read'):
self.stream = filename_or_stream
else:
self.stream = open(filename_or_stream, 'rb')
try:
self.opf_path = os.path.splitext(
os.path.basename(self.stream.name))[0] + '.opf'
except AttributeError:
self.opf_path = 'content.opf'
if self.magic != 'ITOLITLS':
raise LitError('Not a valid LIT file')
if self.version != 1:
raise LitError('Unknown LIT version %d' % (self.version,))
self.read_secondary_header()
self.read_header_pieces()
self.read_section_names()
self.read_manifest()
self.read_drm()
def warn(self, msg):
print "WARNING: %s" % (msg,)
def magic(): def magic():
@preserve @preserve
def fget(self): def fget(self):
self._stream.seek(0) self.stream.seek(0)
return self._stream.read(8) return self.stream.read(8)
return property(fget=fget) return property(fget=fget)
magic = magic() magic = magic()
def version(): def version():
def fget(self): def fget(self):
self._stream.seek(8) self.stream.seek(8)
return u32(self._stream.read(4)) return u32(self.stream.read(4))
return property(fget=fget) return property(fget=fget)
version = version() version = version()
def hdr_len(): def hdr_len():
@preserve @preserve
def fget(self): def fget(self):
self._stream.seek(12) self.stream.seek(12)
return int32(self._stream.read(4)) return int32(self.stream.read(4))
return property(fget=fget) return property(fget=fget)
hdr_len = hdr_len() hdr_len = hdr_len()
def num_pieces(): def num_pieces():
@preserve @preserve
def fget(self): def fget(self):
self._stream.seek(16) self.stream.seek(16)
return int32(self._stream.read(4)) return int32(self.stream.read(4))
return property(fget=fget) return property(fget=fget)
num_pieces = num_pieces() num_pieces = num_pieces()
def sec_hdr_len(): def sec_hdr_len():
@preserve @preserve
def fget(self): def fget(self):
self._stream.seek(20) self.stream.seek(20)
return int32(self._stream.read(4)) return int32(self.stream.read(4))
return property(fget=fget) return property(fget=fget)
sec_hdr_len = sec_hdr_len() sec_hdr_len = sec_hdr_len()
def guid(): def guid():
@preserve @preserve
def fget(self): def fget(self):
self._stream.seek(24) self.stream.seek(24)
return self._stream.read(16) return self.stream.read(16)
return property(fget=fget) return property(fget=fget)
guid = guid() guid = guid()
@ -442,44 +481,27 @@ class LitReader(object):
size = self.hdr_len \ size = self.hdr_len \
+ (self.num_pieces * self.PIECE_SIZE) \ + (self.num_pieces * self.PIECE_SIZE) \
+ self.sec_hdr_len + self.sec_hdr_len
self._stream.seek(0) self.stream.seek(0)
return self._stream.read(size) return self.stream.read(size)
return property(fget=fget) return property(fget=fget)
header = header() header = header()
def __init__(self, filename_or_stream):
if hasattr(filename_or_stream, 'read'):
self._stream = filename_or_stream
else:
self._stream = open(filename_or_stream, 'rb')
if self.magic != 'ITOLITLS':
raise LitError('Not a valid LIT file')
if self.version != 1:
raise LitError('Unknown LIT version %d' % (self.version,))
self.entries = {}
self._read_secondary_header()
self._read_header_pieces()
self._read_section_names()
self._read_manifest()
self._read_meta()
self._read_drm()
@preserve @preserve
def __len__(self): def __len__(self):
self._stream.seek(0, 2) self.stream.seek(0, 2)
return self._stream.tell() return self.stream.tell()
@preserve @preserve
def _read_raw(self, offset, size): def read_raw(self, offset, size):
self._stream.seek(offset) self.stream.seek(offset)
return self._stream.read(size) return self.stream.read(size)
def _read_content(self, offset, size): def read_content(self, offset, size):
return self._read_raw(self.content_offset + offset, size) return self.read_raw(self.content_offset + offset, size)
def _read_secondary_header(self): def read_secondary_header(self):
offset = self.hdr_len + (self.num_pieces * self.PIECE_SIZE) offset = self.hdr_len + (self.num_pieces * self.PIECE_SIZE)
bytes = self._read_raw(offset, self.sec_hdr_len) bytes = self.read_raw(offset, self.sec_hdr_len)
offset = int32(bytes[4:]) offset = int32(bytes[4:])
while offset < len(bytes): while offset < len(bytes):
blocktype = bytes[offset:offset+4] blocktype = bytes[offset:offset+4]
@ -507,21 +529,21 @@ class LitReader(object):
if not hasattr(self, 'content_offset'): if not hasattr(self, 'content_offset'):
raise LitError('Could not figure out the content offset') raise LitError('Could not figure out the content offset')
def _read_header_pieces(self): def read_header_pieces(self):
src = self.header[self.hdr_len:] src = self.header[self.hdr_len:]
for i in xrange(self.num_pieces): for i in xrange(self.num_pieces):
piece = src[i * self.PIECE_SIZE:(i + 1) * self.PIECE_SIZE] piece = src[i * self.PIECE_SIZE:(i + 1) * self.PIECE_SIZE]
if u32(piece[4:]) != 0 or u32(piece[12:]) != 0: if u32(piece[4:]) != 0 or u32(piece[12:]) != 0:
raise LitError('Piece %s has 64bit value' % repr(piece)) raise LitError('Piece %s has 64bit value' % repr(piece))
offset, size = u32(piece), int32(piece[8:]) offset, size = u32(piece), int32(piece[8:])
piece = self._read_raw(offset, size) piece = self.read_raw(offset, size)
if i == 0: if i == 0:
continue # Dont need this piece continue # Dont need this piece
elif i == 1: elif i == 1:
if u32(piece[8:]) != self.entry_chunklen or \ if u32(piece[8:]) != self.entry_chunklen or \
u32(piece[12:]) != self.entry_unknown: u32(piece[12:]) != self.entry_unknown:
raise LitError('Secondary header does not match piece') raise LitError('Secondary header does not match piece')
self._read_directory(piece) self.read_directory(piece)
elif i == 2: elif i == 2:
if u32(piece[8:]) != self.count_chunklen or \ if u32(piece[8:]) != self.count_chunklen or \
u32(piece[12:]) != self.count_unknown: u32(piece[12:]) != self.count_unknown:
@ -532,12 +554,13 @@ class LitReader(object):
elif i == 4: elif i == 4:
self.piece4_guid = piece self.piece4_guid = piece
def _read_directory(self, piece): def read_directory(self, piece):
if not piece.startswith('IFCM'): if not piece.startswith('IFCM'):
raise LitError('Header piece #1 is not main directory.') raise LitError('Header piece #1 is not main directory.')
chunk_size, num_chunks = int32(piece[8:12]), int32(piece[24:28]) chunk_size, num_chunks = int32(piece[8:12]), int32(piece[24:28])
if (32 + (num_chunks * chunk_size)) != len(piece): if (32 + (num_chunks * chunk_size)) != len(piece):
raise LitError('IFCM HEADER has incorrect length') raise LitError('IFCM header has incorrect length')
self.entries = {}
for i in xrange(num_chunks): for i in xrange(num_chunks):
offset = 32 + (i * chunk_size) offset = 32 + (i * chunk_size)
chunk = piece[offset:offset + chunk_size] chunk = piece[offset:offset + chunk_size]
@ -571,17 +594,17 @@ class LitReader(object):
entry = DirectoryEntry(name, section, offset, size) entry = DirectoryEntry(name, section, offset, size)
self.entries[name] = entry self.entries[name] = entry
def _read_section_names(self): def read_section_names(self):
if '::DataSpace/NameList' not in self.entries: if '::DataSpace/NameList' not in self.entries:
raise LitError('Lit file does not have a valid NameList') raise LitError('Lit file does not have a valid NameList')
raw = self.get_file('::DataSpace/NameList') raw = self.get_file('::DataSpace/NameList')
if len(raw) < 4: if len(raw) < 4:
raise LitError('Invalid Namelist section') raise LitError('Invalid Namelist section')
pos = 4 pos = 4
self.num_sections = u16(raw[2:pos]) num_sections = u16(raw[2:pos])
self.section_names = [""]*self.num_sections self.section_names = [""] * num_sections
self.section_data = [None]*self.num_sections self.section_data = [None] * num_sections
for section in xrange(self.num_sections): for section in xrange(num_sections):
size = u16(raw[pos:pos+2]) size = u16(raw[pos:pos+2])
pos += 2 pos += 2
size = size*2 + 2 size = size*2 + 2
@ -591,11 +614,12 @@ class LitReader(object):
raw[pos:pos+size].decode('utf-16-le').rstrip('\000') raw[pos:pos+size].decode('utf-16-le').rstrip('\000')
pos += size pos += size
def _read_manifest(self): def read_manifest(self):
if '/manifest' not in self.entries: if '/manifest' not in self.entries:
raise LitError('Lit file does not have a valid manifest') raise LitError('Lit file does not have a valid manifest')
raw = self.get_file('/manifest') raw = self.get_file('/manifest')
self.manifest = {} self.manifest = {}
self.paths = {self.opf_path: None}
while raw: while raw:
slen, raw = ord(raw[0]), raw[1:] slen, raw = ord(raw[0]), raw[1:]
if slen == 0: break if slen == 0: break
@ -634,28 +658,9 @@ class LitReader(object):
for item in mlist: for item in mlist:
if item.path[0] == '/': if item.path[0] == '/':
item.path = os.path.basename(item.path) item.path = os.path.basename(item.path)
self.paths[item.path] = item
def _pretty_print(self, xml): def read_drm(self):
f = cStringIO.StringIO(xml.encode('utf-8'))
doc = etree.parse(f, parser=self.XML_PARSER)
pretty = etree.tostring(doc, encoding='ascii', pretty_print=True)
return XML_DECL + unicode(pretty)
def _read_meta(self):
path = 'content.opf'
raw = self.get_file('/meta')
xml = OPF_DECL
try:
xml += unicode(UnBinary(raw, path, self.manifest, OPF_MAP))
except LitError:
if 'PENGUIN group' not in raw: raise
print "WARNING: attempting PENGUIN malformed OPF fix"
raw = raw.replace(
'PENGUIN group', '\x00\x01\x18\x00PENGUIN group', 1)
xml += unicode(UnBinary(raw, path, self.manifest, OPF_MAP))
self.meta = xml
def _read_drm(self):
self.drmlevel = 0 self.drmlevel = 0
if '/DRMStorage/Licenses/EUL' in self.entries: if '/DRMStorage/Licenses/EUL' in self.entries:
self.drmlevel = 5 self.drmlevel = 5
@ -666,7 +671,7 @@ class LitReader(object):
else: else:
return return
if self.drmlevel < 5: if self.drmlevel < 5:
msdes.deskey(self._calculate_deskey(), msdes.DE1) msdes.deskey(self.calculate_deskey(), msdes.DE1)
bookkey = msdes.des(self.get_file('/DRMStorage/DRMSealed')) bookkey = msdes.des(self.get_file('/DRMStorage/DRMSealed'))
if bookkey[0] != '\000': if bookkey[0] != '\000':
raise LitError('Unable to decrypt title key!') raise LitError('Unable to decrypt title key!')
@ -674,7 +679,7 @@ class LitReader(object):
else: else:
raise DRMError("Cannot access DRM-protected book") raise DRMError("Cannot access DRM-protected book")
def _calculate_deskey(self): def calculate_deskey(self):
hashfiles = ['/meta', '/DRMStorage/DRMSource'] hashfiles = ['/meta', '/DRMStorage/DRMSource']
if self.drmlevel == 3: if self.drmlevel == 3:
hashfiles.append('/DRMStorage/DRMBookplate') hashfiles.append('/DRMStorage/DRMBookplate')
@ -698,18 +703,18 @@ class LitReader(object):
def get_file(self, name): def get_file(self, name):
entry = self.entries[name] entry = self.entries[name]
if entry.section == 0: if entry.section == 0:
return self._read_content(entry.offset, entry.size) return self.read_content(entry.offset, entry.size)
section = self.get_section(entry.section) section = self.get_section(entry.section)
return section[entry.offset:entry.offset+entry.size] return section[entry.offset:entry.offset+entry.size]
def get_section(self, section): def get_section(self, section):
data = self.section_data[section] data = self.section_data[section]
if not data: if not data:
data = self._get_section(section) data = self.get_section_uncached(section)
self.section_data[section] = data self.section_data[section] = data
return data return data
def _get_section(self, section): def get_section_uncached(self, section):
name = self.section_names[section] name = self.section_names[section]
path = '::DataSpace/Storage/' + name path = '::DataSpace/Storage/' + name
transform = self.get_file(path + '/Transform/List') transform = self.get_file(path + '/Transform/List')
@ -721,29 +726,29 @@ class LitReader(object):
raise LitError("ControlData is too short") raise LitError("ControlData is too short")
guid = msguid(transform) guid = msguid(transform)
if guid == DESENCRYPT_GUID: if guid == DESENCRYPT_GUID:
content = self._decrypt(content) content = self.decrypt(content)
control = control[csize:] control = control[csize:]
elif guid == LZXCOMPRESS_GUID: elif guid == LZXCOMPRESS_GUID:
reset_table = self.get_file( reset_table = self.get_file(
'/'.join(('::DataSpace/Storage', name, 'Transform', '/'.join(('::DataSpace/Storage', name, 'Transform',
LZXCOMPRESS_GUID, 'InstanceData/ResetTable'))) LZXCOMPRESS_GUID, 'InstanceData/ResetTable')))
content = self._decompress(content, control, reset_table) content = self.decompress(content, control, reset_table)
control = control[csize:] control = control[csize:]
else: else:
raise LitError("Unrecognized transform: %s." % repr(guid)) raise LitError("Unrecognized transform: %s." % repr(guid))
transform = transform[16:] transform = transform[16:]
return content return content
def _decrypt(self, content): def decrypt(self, content):
length = len(content) length = len(content)
extra = length & 0x7 extra = length & 0x7
if extra > 0: if extra > 0:
self._warn("content length not a multiple of block size") self.warn("content length not a multiple of block size")
content += "\0" * (8 - extra) content += "\0" * (8 - extra)
msdes.deskey(self.bookkey, msdes.DE1) msdes.deskey(self.bookkey, msdes.DE1)
return msdes.des(content) return msdes.des(content)
def _decompress(self, content, control, reset_table): def decompress(self, content, control, reset_table):
if len(control) < 32 or control[CONTROL_TAG:CONTROL_TAG+4] != "LZXC": if len(control) < 32 or control[CONTROL_TAG:CONTROL_TAG+4] != "LZXC":
raise LitError("Invalid ControlData tag value") raise LitError("Invalid ControlData tag value")
if len(reset_table) < (RESET_INTERVAL + 8): if len(reset_table) < (RESET_INTERVAL + 8):
@ -784,7 +789,7 @@ class LitReader(object):
result.append( result.append(
lzx.decompress(content[base:size], window_bytes)) lzx.decompress(content[base:size], window_bytes))
except lzx.LZXError: except lzx.LZXError:
self._warn("LZX decompression error; skipping chunk") self.warn("LZX decompression error; skipping chunk")
bytes_remaining -= window_bytes bytes_remaining -= window_bytes
base = size base = size
accum += int32(reset_table[RESET_INTERVAL:]) accum += int32(reset_table[RESET_INTERVAL:])
@ -794,55 +799,88 @@ class LitReader(object):
try: try:
result.append(lzx.decompress(content[base:], bytes_remaining)) result.append(lzx.decompress(content[base:], bytes_remaining))
except lzx.LZXError: except lzx.LZXError:
self._warn("LZX decompression error; skipping chunk") self.warn("LZX decompression error; skipping chunk")
bytes_remaining = 0 bytes_remaining = 0
if bytes_remaining > 0: if bytes_remaining > 0:
raise LitError("Failed to completely decompress section") raise LitError("Failed to completely decompress section")
return ''.join(result) return ''.join(result)
def get_entry_content(self, entry, pretty_print=False):
if 'spine' in entry.state:
name = '/'.join(('/data', entry.internal, 'content'))
path = entry.path
raw = self.get_file(name)
decl, map = (OPF_DECL, OPF_MAP) \
if name == '/meta' else (HTML_DECL, HTML_MAP)
content = decl + unicode(UnBinary(raw, path, self.manifest, map))
if pretty_print:
content = self._pretty_print(content)
content = content.encode('utf-8')
else:
name = '/'.join(('/data', entry.internal))
content = self.get_file(name)
return content
def extract_content(self, output_dir=os.getcwdu(), pretty_print=False):
output_dir = os.path.abspath(output_dir)
try:
opf_path = os.path.splitext(
os.path.basename(self._stream.name))[0] + '.opf'
except AttributeError:
opf_path = 'content.opf'
opf_path = os.path.join(output_dir, opf_path)
self._ensure_dir(opf_path)
with open(opf_path, 'wb') as f:
xml = self.meta
if pretty_print:
xml = self._pretty_print(xml)
f.write(xml.encode('utf-8'))
for entry in self.manifest.values():
path = os.path.join(output_dir, entry.path)
self._ensure_dir(path)
with open(path, 'wb') as f:
f.write(self.get_entry_content(entry, pretty_print))
class LitReader(object):
def __init__(self, filename_or_stream):
self._litfile = LitFile(filename_or_stream)
def namelist(self):
return self._litfile.paths.keys()
def exists(self, name):
return urlunquote(name) in self._litfile.paths
def read_xml(self, name):
entry = self._litfile.paths[urlunquote(name)] if name else None
if entry is None:
content = self._read_meta()
elif 'spine' in entry.state:
internal = '/'.join(('/data', entry.internal, 'content'))
raw = self._litfile.get_file(internal)
unbin = UnBinary(raw, name, self._litfile.manifest, HTML_MAP)
content = unbin.tree
else:
raise LitError('Requested non-XML content as XML')
return content
def read(self, name, pretty_print=False):
entry = self._litfile.paths[urlunquote(name)] if name else None
if entry is None:
meta = self._read_meta()
content = OPF_DECL + etree.tostring(
meta, encoding='ascii', pretty_print=pretty_print)
elif 'spine' in entry.state:
internal = '/'.join(('/data', entry.internal, 'content'))
raw = self._litfile.get_file(internal)
unbin = UnBinary(raw, name, self._litfile.manifest, HTML_MAP)
content = HTML_DECL
if pretty_print:
content += etree.tostring(unbin.tree,
encoding='ascii', pretty_print=True)
else:
content += str(unbin)
else:
internal = '/'.join(('/data', entry.internal))
content = self._litfile.get_file(internal)
return content
def meta():
def fget(self):
return self.read(self._litfile.opf_path)
return property(fget=fget)
meta = meta()
def _ensure_dir(self, path): def _ensure_dir(self, path):
dir = os.path.dirname(path) dir = os.path.dirname(path)
if not os.path.isdir(dir): if not os.path.isdir(dir):
os.makedirs(dir) os.makedirs(dir)
def extract_content(self, output_dir=os.getcwdu(), pretty_print=False):
for name in self.namelist():
path = os.path.join(output_dir, name)
self._ensure_dir(path)
with open(path, 'wb') as f:
f.write(self.read(name, pretty_print=pretty_print))
def _read_meta(self):
path = 'content.opf'
raw = self._litfile.get_file('/meta')
try:
unbin = UnBinary(raw, path, self._litfile.manifest, OPF_MAP)
except LitError:
if 'PENGUIN group' not in raw: raise
print "WARNING: attempting PENGUIN malformed OPF fix"
raw = raw.replace(
'PENGUIN group', '\x00\x01\x18\x00PENGUIN group', 1)
unbin = UnBinary(raw, path, self._litfile.manifest, OPF_MAP)
return unbin.tree
def _warn(self, msg):
print "WARNING: %s" % (msg,)
def option_parser(): def option_parser():
from calibre.utils.config import OptionParser from calibre.utils.config import OptionParser
@ -852,7 +890,8 @@ def option_parser():
help=_('Output directory. Defaults to current directory.')) help=_('Output directory. Defaults to current directory.'))
parser.add_option( parser.add_option(
'-p', '--pretty-print', default=False, action='store_true', '-p', '--pretty-print', default=False, action='store_true',
help=_('Legibly format extracted markup. May modify meaningful whitespace.')) help=_('Legibly format extracted markup.' \
' May modify meaningful whitespace.'))
parser.add_option( parser.add_option(
'--verbose', default=False, action='store_true', '--verbose', default=False, action='store_true',
help=_('Useful for debugging.')) help=_('Useful for debugging.'))