diff --git a/src/calibre/ebooks/html_transform_rules.py b/src/calibre/ebooks/html_transform_rules.py
index 6230bed4c2..022d725b99 100644
--- a/src/calibre/ebooks/html_transform_rules.py
+++ b/src/calibre/ebooks/html_transform_rules.py
@@ -5,6 +5,7 @@
from functools import partial
from html5_parser import parse
+from lxml import etree
from calibre.ebooks.oeb.parse_utils import XHTML
from calibre.ebooks.oeb.base import OEB_DOCS, XPath
@@ -259,6 +260,47 @@ def wrap(data, tag):
return True
+def parse_html_snippet(text):
+ return parse(f'
{text}
', namespace_elements=True, fragment_context='div')[0]
+
+
+def clone(src_element, target_tree):
+ if src_element.tag is etree.Comment:
+ ans = etree.Comment('')
+ else:
+ ans = target_tree.makeelement(src_element.tag)
+ for k, v in src_element.items():
+ ans.set(k, v)
+ ans.extend(src_element)
+ ans.text = src_element.text
+ ans.tail = src_element.tail
+ return ans
+
+
+def insert_snippet(container, before_children, tag):
+ if before_children:
+ orig_text = tag.text
+ tag.text = container.text
+ if len(container):
+ for i, child in enumerate(reversed(container)):
+ c = clone(child, tag)
+ tag.insert(0, c)
+ if i == 0 and orig_text:
+ c.tail = (c.tail or '') + orig_text
+ else:
+ tag.text = (tag.text or '') + orig_text
+ else:
+ if container.text:
+ if len(tag) > 0:
+ tag[-1].tail = (tag[-1].tail or '') + container.text
+ else:
+ tag.text = (tag.text or '') + container.text
+ for child in container:
+ c = clone(child, tag)
+ tag.append(c)
+ return True
+
+
action_map = {
'rename': lambda data: partial(rename_tag, qualify_tag_name(data)),
'remove': lambda data: remove_tag,
@@ -269,6 +311,8 @@ action_map = {
'remove_attrs': lambda data: partial(remove_attrs, str.split(data)),
'add_attrs': lambda data: partial(add_attrs, parse_attrs(data)),
'wrap': lambda data: partial(wrap, parse_start_tag(data)),
+ 'insert': lambda data: partial(insert_snippet, parse_html_snippet(data), True),
+ 'insert_end': lambda data: partial(insert_snippet, parse_html_snippet(data), False),
}
@@ -426,7 +470,6 @@ def test(return_tests=False): # {{{
self.ae(rule, next(iter(import_rules(export_rules([rule])))))
def test_html_transform_actions(self):
- from lxml import etree
def r(html='hello'):
return parse(namespace_elements=True, html=html)[1]
@@ -502,6 +545,20 @@ def test(return_tests=False): # {{{
self.assertTrue(t('wrap', '
')(p))
ax(p.getparent(), '
tail')
+ p = r('
hellos')[0]
+ self.assertTrue(t('insert', 'texttail')(p))
+ ax(p, '
text
tail
hello
s')
+ p = r('
hellos')[0]
+ self.assertTrue(t('insert', 'text')(p))
+ ax(p, 'texthellos
')
+
+ p = r('hellos')[0]
+ self.assertTrue(t('insert_end', 'texttail')(p))
+ ax(p, '
hellostext
tail
')
+ p = r('
hellostail')[0]
+ self.assertTrue(t('insert_end', 'text')(p))
+ ax(p, '
hellostailtext
')
+
tests = unittest.defaultTestLoader.loadTestsFromTestCase(TestTransforms)
if return_tests:
return tests