diff --git a/src/css_selectors/__init__.py b/src/css_selectors/__init__.py index 191e1dac64..81ed0f5b87 100644 --- a/src/css_selectors/__init__.py +++ b/src/css_selectors/__init__.py @@ -6,4 +6,8 @@ from __future__ import (unicode_literals, division, absolute_import, __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' +from css_selectors.parse import parse +from css_selectors.select import Select +from css_selectors.errors import SelectorError, SelectorSyntaxError, ExpressionError +__all__ = ['parse', 'Select', 'SelectorError', 'SelectorSyntaxError', 'ExpressionError'] diff --git a/src/css_selectors/select.py b/src/css_selectors/select.py index 5c2139855a..3da128fb67 100644 --- a/src/css_selectors/select.py +++ b/src/css_selectors/select.py @@ -80,6 +80,28 @@ def normalize_language_tag(tag): class Select(object): + ''' + + This class implements CSS Level 3 selectors on an lxml tree, with caching + for performance. To use: + + >>> select = Select(root) + >>> print(tuple(select('p.myclass'))) + + Tags are returned in document order. Note that attribute and tag names are + matched case-insensitively. Also namespaces are ignored (this is for + performance of the common case). + + WARNING: This class uses internal caches. You *must not* make any changes + to the lxml tree. If you do make some changes, either create a new Select + object or call :meth:`invalidate_caches`. + + This class can be easily sub-classes to work with tree implementations + other than lxml. Simply override the methods in the ``Tree Integration`` + block. + + ''' + combinator_mapping = { ' ': 'descendant', '>': 'child', @@ -106,6 +128,7 @@ class Select(object): self.dispatch_map = {k:trace_wrapper(v) for k, v in self.dispatch_map.iteritems()} def invalidate_caches(self): + 'Invalidate all caches. You must call this before using this object if you have made changes to the HTML tree' self._element_map = None self._id_map = None self._class_map = None @@ -114,6 +137,7 @@ class Select(object): self._lang_map = None def __call__(self, selector): + 'Return an iterator over all matching tags, in document order.' seen = set() for selector in get_parsed_selector(selector): parsed_selector = selector.parsed_tree @@ -140,7 +164,7 @@ class Select(object): def map_tag_name(x): return ascii_lower(x.rpartition('}')[2]) - for tag in self.root.iter('*'): + for tag in self.itertag(): em[map_tag_name(tag.tag)].add(tag) return self._element_map @@ -149,7 +173,7 @@ class Select(object): if self._id_map is None: self._id_map = im = defaultdict(OrderedSet) lower = ascii_lower - for elem in get_compiled_xpath('//*[@id]')(self.root): + for elem in self.iteridtags(): im[lower(elem.get('id'))].add(elem) return self._id_map @@ -158,7 +182,7 @@ class Select(object): if self._class_map is None: self._class_map = cm = defaultdict(OrderedSet) lower = ascii_lower - for elem in get_compiled_xpath('//*[@class]')(self.root): + for elem in self.iterclasstags(): for cls in elem.get('class').split(): cm[lower(cls)].add(elem) return self._class_map @@ -171,7 +195,7 @@ class Select(object): if '{' in self.root.tag: def map_attrib_name(x): return ascii_lower(x.rpartition('}')[2]) - for tag in self.root.iter('*'): + for tag in self.itertag(): for attr, val in tag.attrib.iteritems(): am[map_attrib_name(attr)][val].add(tag) return self._attrib_map @@ -184,7 +208,7 @@ class Select(object): if '{' in self.root.tag: def map_attrib_name(x): return ascii_lower(x.rpartition('}')[2]) - for tag in self.root.iter('*'): + for tag in self.itertag(): for attr, val in tag.attrib.iteritems(): for v in val.split(): am[map_attrib_name(attr)][v].add(tag) @@ -195,20 +219,40 @@ class Select(object): if self._lang_map is None: self._lang_map = lm = defaultdict(OrderedSet) dl = normalize_language_tag(self.default_lang) if self.default_lang else None - lmap = {tag:dl for tag in self.root.iter('*')} if dl else {} - for tag in self.root.iter('*'): + lmap = {tag:dl for tag in self.itertag()} if dl else {} + for tag in self.itertag(): lang = None for attr in ('{http://www.w3.org/XML/1998/namespace}lang', 'lang'): lang = tag.get(attr) if lang: lang = normalize_language_tag(lang) - for dtag in tag.iter('*'): + for dtag in self.itertag(tag): lmap[dtag] = lang for tag, langs in lmap.iteritems(): for lang in langs: lm[lang].add(tag) return self._lang_map + # Tree Integration {{{ + def itertag(self, tag=None): + return (self.root if tag is None else tag).iter('*') + + def iterdescendants(self, tag=None): + return (self.root if tag is None else tag).iterdescendants('*') + + def iterchildren(self, tag=None): + return (self.root if tag is None else tag).iterchildren('*') + + def itersiblings(self, tag=None, preceding=False): + return (self.root if tag is None else tag).itersiblings('*', preceding=preceding) + + def iteridtags(self): + return get_compiled_xpath('//*[@id]')(self.root) + + def iterclasstags(self): + return get_compiled_xpath('//*[@class]')(self.root) + # }}} + # Combinators {{{ def select_combinedselector(cache, combined): @@ -217,39 +261,39 @@ def select_combinedselector(cache, combined): # Fast path for when the sub-selector is all elements right = None if isinstance(combined.subselector, Element) and ( combined.subselector.element or '*') == '*' else cache.iterparsedselector(combined.subselector) - for item in cache.dispatch_map[combinator](cache.iterparsedselector(combined.selector), right): + for item in cache.dispatch_map[combinator](cache, cache.iterparsedselector(combined.selector), right): yield item -def select_descendant(left, right): +def select_descendant(cache, left, right): """right is a child, grand-child or further descendant of left""" right = always_in if right is None else frozenset(right) for ancestor in left: - for descendant in ancestor.iterdescendants('*'): + for descendant in cache.iterdescendants(ancestor): if descendant in right: yield descendant -def select_child(left, right): +def select_child(cache, left, right): """right is an immediate child of left""" right = always_in if right is None else frozenset(right) for parent in left: - for child in parent.iterchildren('*'): + for child in cache.iterchildren(parent): if child in right: yield child -def select_direct_adjacent(left, right): +def select_direct_adjacent(cache, left, right): """right is a sibling immediately after left""" right = always_in if right is None else frozenset(right) for parent in left: - for sibling in parent.itersiblings('*'): + for sibling in cache.itersiblings(parent): if sibling in right: yield sibling break -def select_indirect_adjacent(left, right): +def select_indirect_adjacent(cache, left, right): """right is a sibling after left, immediately or not""" right = always_in if right is None else frozenset(right) for parent in left: - for sibling in parent.itersiblings('*'): + for sibling in cache.itersiblings(parent): if sibling in right: yield sibling # }}} @@ -258,7 +302,7 @@ def select_element(cache, selector): """A type or universal selector.""" element = selector.element if not element or element == '*': - for elem in cache.root.iter('*'): + for elem in cache.itertag(): yield elem else: for elem in cache.element_map[ascii_lower(element)]: @@ -367,4 +411,4 @@ if __name__ == '__main__': from pprint import pprint root = etree.fromstring('

') select = Select(root, trace=True) - pprint(list(select(':lang(en)'))) + pprint(list(select('p a')))