Preserve attribute order when parsing

This commit is contained in:
Kovid Goyal 2013-10-26 14:57:42 +05:30
parent 12a581786b
commit dd676227b8
4 changed files with 153 additions and 118 deletions

View File

@ -9,6 +9,7 @@ __copyright__ = '2013, Kovid Goyal <kovid at kovidgoyal.net>'
import copy, re, warnings import copy, re, warnings
from functools import partial from functools import partial
from bisect import bisect from bisect import bisect
from collections import OrderedDict
from lxml.etree import ElementBase, XMLParser, ElementDefaultClassLookup, CommentBase from lxml.etree import ElementBase, XMLParser, ElementDefaultClassLookup, CommentBase
@ -203,7 +204,7 @@ def create_lxml_context():
# }}} # }}}
def process_attribs(attrs, nsmap): def process_attribs(attrs, nsmap):
attribs = {} attrib_name_map = {}
namespaced_attribs = {} namespaced_attribs = {}
xmlns = namespaces['xmlns'] xmlns = namespaces['xmlns']
for k, v in attrs.iteritems(): for k, v in attrs.iteritems():
@ -220,7 +221,7 @@ def process_attribs(attrs, nsmap):
if ns == xlink_ns: if ns == xlink_ns:
del nsmap[prefix] del nsmap[prefix]
nsmap['xlink'] = xlink_ns nsmap['xlink'] = xlink_ns
attribs['{%s}%s' % (k[2], k[1])] = v attrib_name_map[k] = '{%s}%s' % (k[2], k[1])
else: else:
if ':' in k: if ':' in k:
if k.startswith('xmlns') and (k.startswith('xmlns:') or k == 'xmlns'): if k.startswith('xmlns') and (k.startswith('xmlns:') or k == 'xmlns'):
@ -228,27 +229,46 @@ def process_attribs(attrs, nsmap):
if prefix is not None: if prefix is not None:
# Use an existing prefix for this namespace, if # Use an existing prefix for this namespace, if
# possible # possible
existing = {v:k for k, v in nsmap.iteritems()}.get(v, False) existing = {x:k for k, x in nsmap.iteritems()}.get(v, False)
if existing is not False: if existing is not False:
prefix = existing prefix = existing
nsmap[prefix] = v nsmap[prefix] = v
else: else:
namespaced_attribs[k] = v namespaced_attribs[k] = v
else: else:
attribs[k] = v attrib_name_map[k] = k
xml_lang = None
for k, v in namespaced_attribs.iteritems(): for k, v in namespaced_attribs.iteritems():
prefix, name = k.partition(':')[0::2] prefix, name = k.partition(':')[0::2]
if prefix == 'xml': if prefix == 'xml':
if name == 'lang': if name == 'lang':
attribs['lang'] = attribs.get('lang', v) xml_lang = v
continue continue
ns = nsmap.get(prefix, None) ns = nsmap.get(prefix, None)
if ns is not None: if ns is not None:
name = '{%s}%s' % (ns, name) name = '{%s}%s' % (ns, name)
attribs[name] =v attrib_name_map[k] = name
return attribs ans = OrderedDict((attrib_name_map.get(k, None), v) for k, v in attrs.iteritems())
ans.pop(None, None)
if xml_lang:
ans['lang'] = ans.get('lang', xml_lang)
return ans
def makeelement_ns(ctx, namespace, name, attrib, nsmap):
try:
elem = ctx.makeelement('{%s}%s' % (namespace, name), nsmap=nsmap)
except ValueError:
elem = ctx.makeelement('{%s}%s' % (namespace, to_xml_name(name)), nsmap=nsmap)
# Unfortunately, lxml randomizes attrib order if passed in the makeelement
# constructor, therefore they have to be set one by one.
for k, v in attrib.iteritems():
try:
elem.set(k, v)
except ValueError:
elem.set(to_xml_name(k), v)
return elem
class TreeBuilder(BaseTreeBuilder): class TreeBuilder(BaseTreeBuilder):
@ -285,11 +305,7 @@ class TreeBuilder(BaseTreeBuilder):
raise NamespacedHTMLPresent(name.rpartition(':')[0]) raise NamespacedHTMLPresent(name.rpartition(':')[0])
prefix, name = name.partition(':')[0::2] prefix, name = name.partition(':')[0::2]
namespace = nsmap.get(prefix, namespace) namespace = nsmap.get(prefix, namespace)
try: elem = makeelement_ns(self.lxml_context, namespace, name, attribs, nsmap)
elem = self.lxml_context.makeelement('{%s}%s' % (namespace, name), attrib=attribs, nsmap=nsmap)
except ValueError:
attribs = {to_xml_name(k):v for k, v in attribs.iteritems()}
elem = self.lxml_context.makeelement('{%s}%s' % (namespace, to_xml_name(name)), attrib=attribs, nsmap=nsmap)
# Ensure that svg and mathml elements get no namespace prefixes # Ensure that svg and mathml elements get no namespace prefixes
if elem.prefix is not None and namespace in known_namespaces: if elem.prefix is not None and namespace in known_namespaces:
@ -297,7 +313,10 @@ class TreeBuilder(BaseTreeBuilder):
if v == namespace: if v == namespace:
del nsmap[k] del nsmap[k]
nsmap[None] = namespace nsmap[None] = namespace
elem = self.lxml_context.makeelement(elem.tag, attrib=elem.attrib, nsmap=nsmap) nelem = self.lxml_context.makeelement(elem.tag, nsmap=nsmap)
for k, v in elem.items(): # Only elem.items() preserves attrib order
nelem.set(k, v)
elem = nelem
# Keep a reference to elem so that lxml does not delete and re-create # Keep a reference to elem so that lxml does not delete and re-create
# it, losing the name related attributes # it, losing the name related attributes
@ -385,12 +404,26 @@ class TreeBuilder(BaseTreeBuilder):
parent.appendChild(Comment(token["data"].replace('--', '- -'))) parent.appendChild(Comment(token["data"].replace('--', '- -')))
def process_namespace_free_attribs(attrs): def process_namespace_free_attribs(attrs):
attribs = {k:v for k, v in attrs.iteritems() if ':' not in k} anm = {k:k for k, v in attrs.iteritems() if ':' not in k}
for k in set(attrs) - set(attribs): for k in frozenset(attrs) - frozenset(anm):
prefix, name = k.partition(':')[0::2] prefix, name = k.partition(':')[0::2]
if prefix != 'xmlns' and name not in attribs: if prefix != 'xmlns' and name not in anm:
attribs[name] = attrs[k] anm[name] = k
return attribs ans = OrderedDict((anm.get(k, None), v) for k, v in attrs.iteritems())
ans.pop(None, None)
return ans
def makeelement(ctx, name, attrib):
try:
elem = ctx.makeelement(name)
except ValueError:
elem = ctx.makeelement(to_xml_name(name))
for k, v in attrib.iteritems():
try:
elem.set(k, v)
except ValueError:
elem.set(to_xml_name(k), v)
return elem
class NoNamespaceTreeBuilder(TreeBuilder): class NoNamespaceTreeBuilder(TreeBuilder):
@ -404,11 +437,7 @@ class NoNamespaceTreeBuilder(TreeBuilder):
def createElement(self, token, nsmap=None): def createElement(self, token, nsmap=None):
name = token['name'].rpartition(':')[2] name = token['name'].rpartition(':')[2]
attribs = process_namespace_free_attribs(token['data']) attribs = process_namespace_free_attribs(token['data'])
try: elem = makeelement(self.lxml_context, name, attribs)
elem = self.lxml_context.makeelement(name, attrib=attribs)
except ValueError:
attribs = {to_xml_name(k):v for k, v in attribs.iteritems()}
elem = self.lxml_context.makeelement(to_xml_name(name), attrib=attribs)
# Keep a reference to elem so that lxml does not delete and re-create # Keep a reference to elem so that lxml does not delete and re-create
# it, losing _namespace # it, losing _namespace
self.proxy_cache.append(elem) self.proxy_cache.append(elem)
@ -551,8 +580,7 @@ def parse(raw, decoder=None, log=None, discard_namespaces=False, line_numbers=Tr
if __name__ == '__main__': if __name__ == '__main__':
from lxml import etree from lxml import etree
root = parse('\n<html><head><title>a\n</title><p>&nbsp;\n<b>b', discard_namespaces=False) root = parse('\n<html><head><title>a\n</title><p b=1 c=2 a=0>&nbsp;\n<b>b<svg ass="wipe" viewbox="0">', discard_namespaces=False)
# root = parse('\n<html><p><svg viewbox="0 0 0 0"><image xlink:href="xxx"/><b></svg>&nbsp;\n<b>xxx', discard_namespaces=False)
print (etree.tostring(root, encoding='utf-8')) print (etree.tostring(root, encoding='utf-8'))
print() print()

View File

@ -135,7 +135,15 @@ def multiple_html_and_body(test, parse_function):
test.assertEqual(len(XPath('//h:html[@id and @lang]')(root)), 1, err) test.assertEqual(len(XPath('//h:html[@id and @lang]')(root)), 1, err)
test.assertEqual(len(XPath('//h:body[@id and @lang]')(root)), 1, err) test.assertEqual(len(XPath('//h:body[@id and @lang]')(root)), 1, err)
basic_checks = (nonvoid_cdata_elements, namespaces, space_characters, case_insensitive_element_names, entities, multiple_html_and_body) def attribute_replacement(test, parse_function):
markup = '<html><body><svg viewbox="0"></svg><svg xmlns="%s" viewbox="1">' % SVG_NS
root = parse_function(markup)
err = 'SVG attributes not normalized, parsed markup:\n' + etree.tostring(root)
test.assertEqual(len(XPath('//svg:svg[@viewBox]')(root)), 2, err)
basic_checks = (nonvoid_cdata_elements, namespaces, space_characters,
case_insensitive_element_names, entities,
multiple_html_and_body, attribute_replacement)
class ParsingTests(BaseTest): class ParsingTests(BaseTest):
@ -160,3 +168,10 @@ class ParsingTests(BaseTest):
elem = root.xpath('//*[local-name()="%s"]' % tag)[0] elem = root.xpath('//*[local-name()="%s"]' % tag)[0]
self.assertEqual(lnum, elem.sourceline, 'Line number incorrect for %s, source: %s:' % (tag, src)) self.assertEqual(lnum, elem.sourceline, 'Line number incorrect for %s, source: %s:' % (tag, src))
for ds in (False, True):
src = '\n<html>\n<p b=1 a=2 c=3 d=4 e=5 f=6 g=7 h=8><svg b=1 a=2 c=3 d=4 e=5 f=6 g=7 h=8>\n'
root = parse(src, discard_namespaces=ds)
for tag in ('p', 'svg'):
for i, (k, v) in enumerate(root.xpath('//*[local-name()="%s"]' % tag)[0].items()):
self.assertEqual(i+1, int(v))

View File

@ -433,6 +433,73 @@ mathmlTextIntegrationPointElements = frozenset((
(namespaces["mathml"], "mtext") (namespaces["mathml"], "mtext")
)) ))
adjustSVGAttributes = {
"attributename": "attributeName",
"attributetype": "attributeType",
"basefrequency": "baseFrequency",
"baseprofile": "baseProfile",
"calcmode": "calcMode",
"clippathunits": "clipPathUnits",
"contentscripttype": "contentScriptType",
"contentstyletype": "contentStyleType",
"diffuseconstant": "diffuseConstant",
"edgemode": "edgeMode",
"externalresourcesrequired": "externalResourcesRequired",
"filterres": "filterRes",
"filterunits": "filterUnits",
"glyphref": "glyphRef",
"gradienttransform": "gradientTransform",
"gradientunits": "gradientUnits",
"kernelmatrix": "kernelMatrix",
"kernelunitlength": "kernelUnitLength",
"keypoints": "keyPoints",
"keysplines": "keySplines",
"keytimes": "keyTimes",
"lengthadjust": "lengthAdjust",
"limitingconeangle": "limitingConeAngle",
"markerheight": "markerHeight",
"markerunits": "markerUnits",
"markerwidth": "markerWidth",
"maskcontentunits": "maskContentUnits",
"maskunits": "maskUnits",
"numoctaves": "numOctaves",
"pathlength": "pathLength",
"patterncontentunits": "patternContentUnits",
"patterntransform": "patternTransform",
"patternunits": "patternUnits",
"pointsatx": "pointsAtX",
"pointsaty": "pointsAtY",
"pointsatz": "pointsAtZ",
"preservealpha": "preserveAlpha",
"preserveaspectratio": "preserveAspectRatio",
"primitiveunits": "primitiveUnits",
"refx": "refX",
"refy": "refY",
"repeatcount": "repeatCount",
"repeatdur": "repeatDur",
"requiredextensions": "requiredExtensions",
"requiredfeatures": "requiredFeatures",
"specularconstant": "specularConstant",
"specularexponent": "specularExponent",
"spreadmethod": "spreadMethod",
"startoffset": "startOffset",
"stddeviation": "stdDeviation",
"stitchtiles": "stitchTiles",
"surfacescale": "surfaceScale",
"systemlanguage": "systemLanguage",
"tablevalues": "tableValues",
"targetx": "targetX",
"targety": "targetY",
"textlength": "textLength",
"viewbox": "viewBox",
"viewtarget": "viewTarget",
"xchannelselector": "xChannelSelector",
"ychannelselector": "yChannelSelector",
"zoomandpan": "zoomAndPan"
}
adjustMathMLAttributes = {"definitionurl": "definitionURL"}
adjustForeignAttributes = { adjustForeignAttributes = {
"xlink:actuate": ("xlink", "actuate", namespaces["xlink"]), "xlink:actuate": ("xlink", "actuate", namespaces["xlink"]),
"xlink:arcrole": ("xlink", "arcrole", namespaces["xlink"]), "xlink:arcrole": ("xlink", "arcrole", namespaces["xlink"]),

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import, division, unicode_literals
from six import with_metaclass from six import with_metaclass
import types import types
from collections import OrderedDict
from . import inputstream from . import inputstream
from . import tokenizer from . import tokenizer
@ -10,14 +11,12 @@ from . import treebuilders
from .treebuilders._base import Marker from .treebuilders._base import Marker
from . import utils from . import utils
from . import constants from .constants import (
from .constants import spaceCharacters, asciiUpper2Lower spaceCharacters, asciiUpper2Lower, specialElements, headingElements,
from .constants import specialElements cdataElements, rcdataElements, tokenTypes, tagTokenTypes, ReparseException, namespaces,
from .constants import headingElements htmlIntegrationPointElements, mathmlTextIntegrationPointElements,
from .constants import cdataElements, rcdataElements adjustForeignAttributes as adjustForeignAttributesMap, adjustSVGAttributes,
from .constants import tokenTypes, ReparseException, namespaces adjustMathMLAttributes)
from .constants import htmlIntegrationPointElements, mathmlTextIntegrationPointElements
from .constants import adjustForeignAttributes as adjustForeignAttributesMap
def parse(doc, treebuilder="etree", encoding=None, def parse(doc, treebuilder="etree", encoding=None,
@ -255,96 +254,18 @@ class HTMLParser(object):
""" HTML5 specific normalizations to the token stream """ """ HTML5 specific normalizations to the token stream """
if token["type"] == tokenTypes["StartTag"]: if token["type"] == tokenTypes["StartTag"]:
token["data"] = dict(token["data"][::-1]) token["data"] = OrderedDict(token['data'])
return token return token
def adjustMathMLAttributes(self, token): def adjustMathMLAttributes(self, token):
replacements = {"definitionurl": "definitionURL"} adjust_attributes(token, adjustMathMLAttributes)
for k, v in replacements.items():
if k in token["data"]:
token["data"][v] = token["data"][k]
del token["data"][k]
def adjustSVGAttributes(self, token): def adjustSVGAttributes(self, token):
replacements = { adjust_attributes(token, adjustSVGAttributes)
"attributename": "attributeName",
"attributetype": "attributeType",
"basefrequency": "baseFrequency",
"baseprofile": "baseProfile",
"calcmode": "calcMode",
"clippathunits": "clipPathUnits",
"contentscripttype": "contentScriptType",
"contentstyletype": "contentStyleType",
"diffuseconstant": "diffuseConstant",
"edgemode": "edgeMode",
"externalresourcesrequired": "externalResourcesRequired",
"filterres": "filterRes",
"filterunits": "filterUnits",
"glyphref": "glyphRef",
"gradienttransform": "gradientTransform",
"gradientunits": "gradientUnits",
"kernelmatrix": "kernelMatrix",
"kernelunitlength": "kernelUnitLength",
"keypoints": "keyPoints",
"keysplines": "keySplines",
"keytimes": "keyTimes",
"lengthadjust": "lengthAdjust",
"limitingconeangle": "limitingConeAngle",
"markerheight": "markerHeight",
"markerunits": "markerUnits",
"markerwidth": "markerWidth",
"maskcontentunits": "maskContentUnits",
"maskunits": "maskUnits",
"numoctaves": "numOctaves",
"pathlength": "pathLength",
"patterncontentunits": "patternContentUnits",
"patterntransform": "patternTransform",
"patternunits": "patternUnits",
"pointsatx": "pointsAtX",
"pointsaty": "pointsAtY",
"pointsatz": "pointsAtZ",
"preservealpha": "preserveAlpha",
"preserveaspectratio": "preserveAspectRatio",
"primitiveunits": "primitiveUnits",
"refx": "refX",
"refy": "refY",
"repeatcount": "repeatCount",
"repeatdur": "repeatDur",
"requiredextensions": "requiredExtensions",
"requiredfeatures": "requiredFeatures",
"specularconstant": "specularConstant",
"specularexponent": "specularExponent",
"spreadmethod": "spreadMethod",
"startoffset": "startOffset",
"stddeviation": "stdDeviation",
"stitchtiles": "stitchTiles",
"surfacescale": "surfaceScale",
"systemlanguage": "systemLanguage",
"tablevalues": "tableValues",
"targetx": "targetX",
"targety": "targetY",
"textlength": "textLength",
"viewbox": "viewBox",
"viewtarget": "viewTarget",
"xchannelselector": "xChannelSelector",
"ychannelselector": "yChannelSelector",
"zoomandpan": "zoomAndPan"
}
for originalName in list(token["data"].keys()):
if originalName in replacements:
svgName = replacements[originalName]
token["data"][svgName] = token["data"][originalName]
del token["data"][originalName]
def adjustForeignAttributes(self, token): def adjustForeignAttributes(self, token):
replacements = adjustForeignAttributesMap adjust_attributes(token, adjustForeignAttributesMap)
for originalName in token["data"].keys():
if originalName in replacements:
foreignName = replacements[originalName]
token["data"][foreignName] = token["data"][originalName]
del token["data"][originalName]
def reparseTokenNormal(self, token): def reparseTokenNormal(self, token):
self.parser.phase() self.parser.phase()
@ -424,7 +345,7 @@ def getPhases(debug):
def log(function): def log(function):
"""Logger that records which phase processes each token""" """Logger that records which phase processes each token"""
type_names = dict((value, key) for key, value in type_names = dict((value, key) for key, value in
constants.tokenTypes.items()) tokenTypes.items())
def wrapped(self, *args, **kwargs): def wrapped(self, *args, **kwargs):
if function.__name__.startswith("process") and len(args) > 0: if function.__name__.startswith("process") and len(args) > 0:
@ -433,7 +354,7 @@ def getPhases(debug):
info = {"type": type_names[token['type']]} info = {"type": type_names[token['type']]}
except: except:
raise raise
if token['type'] in constants.tagTokenTypes: if token['type'] in tagTokenTypes:
info["name"] = token['name'] info["name"] = token['name']
self.parser.log.append((self.parser.tokenizer.state.__name__, self.parser.log.append((self.parser.tokenizer.state.__name__,
@ -2721,6 +2642,10 @@ def getPhases(debug):
# XXX after after frameset # XXX after after frameset
} }
def adjust_attributes(token, replacements):
if frozenset(token['data']) & frozenset(replacements):
token['data'] = OrderedDict(
(replacements.get(k, k), v) for k, v in token['data'].iteritems())
class ParseError(Exception): class ParseError(Exception):