Book container: Remove references to toc and raster cover from opf when those files are deleted.

This commit is contained in:
Kovid Goyal 2013-10-10 18:30:20 +05:30
parent f27b03e9e4
commit e7305f7004
2 changed files with 27 additions and 0 deletions

View File

@ -345,12 +345,23 @@ class Container(object): # {{{
self.remove_from_xml(elem) self.remove_from_xml(elem)
self.dirty(self.opf_name) self.dirty(self.opf_name)
if removed: if removed:
for spine in self.opf_xpath('//opf:spine'):
tocref = spine.attrib.get('toc', None)
if tocref and tocref in removed:
spine.attrib.pop('toc', None)
self.dirty(self.opf_name)
for item in self.opf_xpath('//opf:spine/opf:itemref[@idref]'): for item in self.opf_xpath('//opf:spine/opf:itemref[@idref]'):
idref = item.get('idref') idref = item.get('idref')
if idref in removed: if idref in removed:
self.remove_from_xml(item) self.remove_from_xml(item)
self.dirty(self.opf_name) self.dirty(self.opf_name)
for meta in self.opf_xpath('//opf:meta[@name="cover" and @content]'):
if meta.get('content') in removed:
self.remove_from_xml(meta)
self.dirty(self.opf_name)
for item in self.opf_xpath('//opf:guide/opf:reference[@href]'): for item in self.opf_xpath('//opf:guide/opf:reference[@href]'):
if self.href_to_name(item.get('href'), self.opf_name) == name: if self.href_to_name(item.get('href'), self.opf_name) == name:
self.remove_from_xml(item) self.remove_from_xml(item)

View File

@ -57,3 +57,19 @@ class ContainerTests(BaseTest):
for c in (c1, c2): for c in (c1, c2):
c.commit(outpath=x) c.commit(outpath=x)
def test_file_removal(self):
' Test removal of files from the container '
book = get_simple_book()
c = get_container(book, tdir=self.tdir)
files = ('toc.ncx', 'cover.png', 'titlepage.xhtml')
self.assertIn('titlepage.xhtml', {x[0] for x in c.spine_names})
self.assertTrue(c.opf_xpath('//opf:meta[@name="cover"]'))
for x in files:
c.remove_item(x)
self.assertIn(c.opf_name, c.dirtied)
self.assertNotIn('titlepage.xhtml', {x[0] for x in c.spine_names})
self.assertFalse(c.opf_xpath('//opf:meta[@name="cover"]'))
raw = c.serialize_item(c.opf_name).decode('utf-8')
for x in files:
self.assertNotIn(x, raw)