diff --git a/src/calibre/ebooks/html_transform_rules.py b/src/calibre/ebooks/html_transform_rules.py
index 6230bed4c2..022d725b99 100644
--- a/src/calibre/ebooks/html_transform_rules.py
+++ b/src/calibre/ebooks/html_transform_rules.py
@@ -5,6 +5,7 @@
 from functools import partial
 
 from html5_parser import parse
+from lxml import etree
 
 from calibre.ebooks.oeb.parse_utils import XHTML
 from calibre.ebooks.oeb.base import OEB_DOCS, XPath
@@ -259,6 +260,76 @@ def wrap(data, tag):
     return True
 
 
+def parse_html_snippet(text):
+    '''
+    Parse the HTML fragment ``text`` and return the element that wraps it.
+
+    insert_snippet() reads the result's ``.text`` and children. The parsed
+    element is bound into the action with partial(), so the same element
+    may be applied to more than one tag.
+    '''
+    # NOTE(review): the wrapper markup was stripped from this copy of the
+    # patch; <div> is assumed from fragment_context='div' -- confirm upstream.
+    return parse(f'<div>{text}</div>', namespace_elements=True, fragment_context='div')[0]
+
+
+def clone(src_element, target_tree):
+    '''
+    Return a deep copy of ``src_element`` (tag, attributes, text, tail and
+    descendants), created via ``target_tree.makeelement()``.
+
+    ``src_element`` itself is left unchanged so it can be cloned again.
+    '''
+    from copy import deepcopy
+    if src_element.tag is etree.Comment:
+        ans = etree.Comment('')
+    else:
+        ans = target_tree.makeelement(src_element.tag)
+        for k, v in src_element.items():
+            ans.set(k, v)
+        # lxml moves an element that already has a parent, so extending with
+        # the source's own children would strip them from the reusable
+        # snippet; copy them instead.
+        ans.extend([deepcopy(child) for child in src_element])
+    ans.text = src_element.text
+    ans.tail = src_element.tail
+    return ans
+
+
+def insert_snippet(container, before_children, tag):
+    '''
+    Copy the text and children of ``container`` into ``tag``.
+
+    If ``before_children`` is true the content goes before the existing
+    text and children of ``tag``, otherwise after them. ``container`` is
+    not modified. Always returns True.
+    '''
+    if before_children:
+        orig_text = tag.text
+        tag.text = container.text
+        if len(container):
+            for i, child in enumerate(reversed(container)):
+                c = clone(child, tag)
+                tag.insert(0, c)
+                # i == 0 is the snippet's last child: the tag's original
+                # leading text must follow it
+                if i == 0 and orig_text:
+                    c.tail = (c.tail or '') + orig_text
+        elif orig_text:
+            # tag.text is None for a tag with no leading text; concatenating
+            # it unguarded raised TypeError
+            tag.text = (tag.text or '') + orig_text
+    else:
+        if container.text:
+            if len(tag) > 0:
+                tag[-1].tail = (tag[-1].tail or '') + container.text
+            else:
+                tag.text = (tag.text or '') + container.text
+        for child in container:
+            c = clone(child, tag)
+            tag.append(c)
+    return True
+
+
 action_map = {
     'rename': lambda data: partial(rename_tag, qualify_tag_name(data)),
     'remove': lambda data: remove_tag,
@@ -269,6 +340,8 @@ action_map = {
     'remove_attrs': lambda data: partial(remove_attrs, str.split(data)),
     'add_attrs': lambda data: partial(add_attrs, parse_attrs(data)),
     'wrap': lambda data: partial(wrap, parse_start_tag(data)),
+    'insert': lambda data: partial(insert_snippet, parse_html_snippet(data), True),
+    'insert_end': lambda data: partial(insert_snippet, parse_html_snippet(data), False),
 }
 
 
@@ -426,7 +499,6 @@ def test(return_tests=False): # {{{
             self.ae(rule, next(iter(import_rules(export_rules([rule])))))
 
         def test_html_transform_actions(self):
-            from lxml import etree
 
             def r(html='

hello'): return parse(namespace_elements=True, html=html)[1] @@ -502,6 +545,20 @@ def test(return_tests=False): # {{{ self.assertTrue(t('wrap', '

')(p)) ax(p.getparent(), '

ts

tail') + p = r('

hellos')[0] + self.assertTrue(t('insert', 'text

tail')(p)) + ax(p, '

text

tail
hellos

') + p = r('

hellos')[0] + self.assertTrue(t('insert', 'text')(p)) + ax(p, '

texthellos

') + + p = r('

hellos')[0] + self.assertTrue(t('insert_end', 'text

tail')(p)) + ax(p, '

hellostext

tail

') + p = r('

hellostail')[0] + self.assertTrue(t('insert_end', 'text')(p)) + ax(p, '

hellostailtext

') + tests = unittest.defaultTestLoader.loadTestsFromTestCase(TestTransforms) if return_tests: return tests