Add some documentation for css_selectors

Also allow the Select class to work with other tree implementations
This commit is contained in:
Kovid Goyal 2015-02-20 09:03:02 +05:30
parent 77726d774a
commit 6150a664c2
2 changed files with 67 additions and 19 deletions

View File

@ -6,4 +6,8 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
from css_selectors.parse import parse
from css_selectors.select import Select
from css_selectors.errors import SelectorError, SelectorSyntaxError, ExpressionError
__all__ = ['parse', 'Select', 'SelectorError', 'SelectorSyntaxError', 'ExpressionError']

View File

@ -80,6 +80,28 @@ def normalize_language_tag(tag):
class Select(object):
'''
This class implements CSS Level 3 selectors on an lxml tree, with caching
for performance. To use:
>>> select = Select(root)
>>> print(tuple(select('p.myclass')))
Tags are returned in document order. Note that attribute and tag names are
matched case-insensitively. Also namespaces are ignored (this is for
performance of the common case).
WARNING: This class uses internal caches. You *must not* make any changes
to the lxml tree. If you do make some changes, either create a new Select
object or call :meth:`invalidate_caches`.
This class can be easily sub-classes to work with tree implementations
other than lxml. Simply override the methods in the ``Tree Integration``
block.
'''
combinator_mapping = {
' ': 'descendant',
'>': 'child',
@ -106,6 +128,7 @@ class Select(object):
self.dispatch_map = {k:trace_wrapper(v) for k, v in self.dispatch_map.iteritems()}
def invalidate_caches(self):
'Invalidate all caches. You must call this before using this object if you have made changes to the HTML tree'
self._element_map = None
self._id_map = None
self._class_map = None
@ -114,6 +137,7 @@ class Select(object):
self._lang_map = None
def __call__(self, selector):
'Return an iterator over all matching tags, in document order.'
seen = set()
for selector in get_parsed_selector(selector):
parsed_selector = selector.parsed_tree
@ -140,7 +164,7 @@ class Select(object):
def map_tag_name(x):
return ascii_lower(x.rpartition('}')[2])
for tag in self.root.iter('*'):
for tag in self.itertag():
em[map_tag_name(tag.tag)].add(tag)
return self._element_map
@ -149,7 +173,7 @@ class Select(object):
if self._id_map is None:
self._id_map = im = defaultdict(OrderedSet)
lower = ascii_lower
for elem in get_compiled_xpath('//*[@id]')(self.root):
for elem in self.iteridtags():
im[lower(elem.get('id'))].add(elem)
return self._id_map
@ -158,7 +182,7 @@ class Select(object):
if self._class_map is None:
self._class_map = cm = defaultdict(OrderedSet)
lower = ascii_lower
for elem in get_compiled_xpath('//*[@class]')(self.root):
for elem in self.iterclasstags():
for cls in elem.get('class').split():
cm[lower(cls)].add(elem)
return self._class_map
@ -171,7 +195,7 @@ class Select(object):
if '{' in self.root.tag:
def map_attrib_name(x):
return ascii_lower(x.rpartition('}')[2])
for tag in self.root.iter('*'):
for tag in self.itertag():
for attr, val in tag.attrib.iteritems():
am[map_attrib_name(attr)][val].add(tag)
return self._attrib_map
@ -184,7 +208,7 @@ class Select(object):
if '{' in self.root.tag:
def map_attrib_name(x):
return ascii_lower(x.rpartition('}')[2])
for tag in self.root.iter('*'):
for tag in self.itertag():
for attr, val in tag.attrib.iteritems():
for v in val.split():
am[map_attrib_name(attr)][v].add(tag)
@ -195,20 +219,40 @@ class Select(object):
if self._lang_map is None:
self._lang_map = lm = defaultdict(OrderedSet)
dl = normalize_language_tag(self.default_lang) if self.default_lang else None
lmap = {tag:dl for tag in self.root.iter('*')} if dl else {}
for tag in self.root.iter('*'):
lmap = {tag:dl for tag in self.itertag()} if dl else {}
for tag in self.itertag():
lang = None
for attr in ('{http://www.w3.org/XML/1998/namespace}lang', 'lang'):
lang = tag.get(attr)
if lang:
lang = normalize_language_tag(lang)
for dtag in tag.iter('*'):
for dtag in self.itertag(tag):
lmap[dtag] = lang
for tag, langs in lmap.iteritems():
for lang in langs:
lm[lang].add(tag)
return self._lang_map
# Tree Integration {{{
def itertag(self, tag=None):
return (self.root if tag is None else tag).iter('*')
def iterdescendants(self, tag=None):
return (self.root if tag is None else tag).iterdescendants('*')
def iterchildren(self, tag=None):
return (self.root if tag is None else tag).iterchildren('*')
def itersiblings(self, tag=None, preceding=False):
return (self.root if tag is None else tag).itersiblings('*', preceding=preceding)
def iteridtags(self):
return get_compiled_xpath('//*[@id]')(self.root)
def iterclasstags(self):
return get_compiled_xpath('//*[@class]')(self.root)
# }}}
# Combinators {{{
def select_combinedselector(cache, combined):
@ -217,39 +261,39 @@ def select_combinedselector(cache, combined):
# Fast path for when the sub-selector is all elements
right = None if isinstance(combined.subselector, Element) and (
combined.subselector.element or '*') == '*' else cache.iterparsedselector(combined.subselector)
for item in cache.dispatch_map[combinator](cache.iterparsedselector(combined.selector), right):
for item in cache.dispatch_map[combinator](cache, cache.iterparsedselector(combined.selector), right):
yield item
def select_descendant(left, right):
def select_descendant(cache, left, right):
"""right is a child, grand-child or further descendant of left"""
right = always_in if right is None else frozenset(right)
for ancestor in left:
for descendant in ancestor.iterdescendants('*'):
for descendant in cache.iterdescendants(ancestor):
if descendant in right:
yield descendant
def select_child(left, right):
def select_child(cache, left, right):
"""right is an immediate child of left"""
right = always_in if right is None else frozenset(right)
for parent in left:
for child in parent.iterchildren('*'):
for child in cache.iterchildren(parent):
if child in right:
yield child
def select_direct_adjacent(left, right):
def select_direct_adjacent(cache, left, right):
"""right is a sibling immediately after left"""
right = always_in if right is None else frozenset(right)
for parent in left:
for sibling in parent.itersiblings('*'):
for sibling in cache.itersiblings(parent):
if sibling in right:
yield sibling
break
def select_indirect_adjacent(left, right):
def select_indirect_adjacent(cache, left, right):
"""right is a sibling after left, immediately or not"""
right = always_in if right is None else frozenset(right)
for parent in left:
for sibling in parent.itersiblings('*'):
for sibling in cache.itersiblings(parent):
if sibling in right:
yield sibling
# }}}
@ -258,7 +302,7 @@ def select_element(cache, selector):
"""A type or universal selector."""
element = selector.element
if not element or element == '*':
for elem in cache.root.iter('*'):
for elem in cache.itertag():
yield elem
else:
for elem in cache.element_map[ascii_lower(element)]:
@ -367,4 +411,4 @@ if __name__ == '__main__':
from pprint import pprint
root = etree.fromstring('<body xmlns="xxx" xml:lang="en"><p id="p" class="one two" lang="fr"><a id="a"/></p></body>')
select = Select(root, trace=True)
pprint(list(select(':lang(en)')))
pprint(list(select('p a')))