From 730ab1098e38522acde8ef0f02295a6da5227335 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Thu, 19 Feb 2015 21:36:18 +0530 Subject: [PATCH] Implement attribute selectors --- src/css_selectors/select.py | 101 +++++++++++++++++++++++++++++++++--- src/css_selectors/tests.py | 98 ++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 7 deletions(-) diff --git a/src/css_selectors/select.py b/src/css_selectors/select.py index 77663364e8..842498dc5c 100644 --- a/src/css_selectors/select.py +++ b/src/css_selectors/select.py @@ -6,6 +6,7 @@ from __future__ import (unicode_literals, division, absolute_import, __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' +import re from collections import OrderedDict, defaultdict from functools import wraps @@ -20,6 +21,9 @@ parse_cache = OrderedDict() XPATH_CACHE_SIZE = 30 xpath_cache = OrderedDict() +# Test that the string is not empty and does not contain whitespace +is_non_whitespace = re.compile(r'^[^ \t\r\n\f]+$').match + def get_parsed_selector(raw): try: return parse_cache[raw] @@ -69,7 +73,6 @@ class Select(object): '^=': 'prefixmatch', '$=': 'suffixmatch', '*=': 'substringmatch', - '!=': 'different', # Not in Level 3 but I like it ;) } def __init__(self, root, dispatch_map=None, trace=False): @@ -83,12 +86,17 @@ class Select(object): self._element_map = None self._id_map = None self._class_map = None + self._attrib_map = None + self._attrib_space_map = None def __call__(self, selector): + seen = set() for selector in get_parsed_selector(selector): parsed_selector = selector.parsed_tree for item in self.iterparsedselector(parsed_selector): - yield item + if item not in seen: + yield item + seen.add(item) def iterparsedselector(self, parsed_selector): type_name = type(parsed_selector).__name__ @@ -108,7 +116,7 @@ class Select(object): def map_tag_name(x): return ascii_lower(x.rpartition('}')[2]) - for tag in root.iter('*'): + for tag in self.root.iter('*'): em[map_tag_name(tag.tag)].add(tag) return self._element_map @@ -131,6 +139,34 @@ class Select(object): cm[lower(cls)].add(elem) return self._class_map + @property + def attrib_map(self): + if self._attrib_map is None: + self._attrib_map = am = defaultdict(lambda : defaultdict(OrderedSet)) + map_attrib_name = ascii_lower + if '{' in self.root.tag: + def map_attrib_name(x): + return ascii_lower(x.rpartition('}')[2]) + for tag in self.root.iter('*'): + for attr, val in tag.attrib.iteritems(): + am[map_attrib_name(attr)][val].add(tag) + return self._attrib_map + + @property + def attrib_space_map(self): + if self._attrib_space_map is None: + self._attrib_space_map = am = defaultdict(lambda : defaultdict(OrderedSet)) + map_attrib_name = ascii_lower + if '{' in self.root.tag: + def map_attrib_name(x): + return ascii_lower(x.rpartition('}')[2]) + for tag in self.root.iter('*'): + for attr, val in tag.attrib.iteritems(): + for v in val.split(): + am[map_attrib_name(attr)][v].add(tag) + return self._attrib_space_map + + # Combinators {{{ def select_combinedselector(cache, combined): @@ -189,12 +225,10 @@ def select_element(cache, selector): def select_hash(cache, selector): 'An id selector' items = cache.id_map[ascii_lower(selector.id)] - if len(items) > 1: + if len(items) > 0: for elem in cache.iterparsedselector(selector.selector): if elem in items: yield elem - elif items: - yield items[0] def select_class(cache, selector): 'A class selector' @@ -204,10 +238,63 @@ def select_class(cache, selector): if elem in items: yield elem +# Attribute selectors {{{ + +def select_attrib(cache, selector): + operator = cache.attribute_operator_mapping[selector.operator] + items = frozenset(cache.dispatch_map[operator](cache, ascii_lower(selector.attrib), selector.value)) + for item in cache.iterparsedselector(selector.selector): + if item in items: + yield item + +def select_exists(cache, attrib, value=None): + for elem_set in cache.attrib_map[attrib].itervalues(): + for elem in elem_set: + yield elem + +def select_equals(cache, attrib, value): + for elem in cache.attrib_map[attrib][value]: + yield elem + +def select_includes(cache, attrib, value): + if is_non_whitespace(value): + for elem in cache.attrib_space_map[attrib][value]: + yield elem + +def select_dashmatch(cache, attrib, value): + if value: + for val, elem_set in cache.attrib_map[attrib].iteritems(): + if val == value or val.startswith(value + '-'): + for elem in elem_set: + yield elem + +def select_prefixmatch(cache, attrib, value): + if value: + for val, elem_set in cache.attrib_map[attrib].iteritems(): + if val.startswith(value): + for elem in elem_set: + yield elem + +def select_suffixmatch(cache, attrib, value): + if value: + for val, elem_set in cache.attrib_map[attrib].iteritems(): + if val.endswith(value): + for elem in elem_set: + yield elem + +def select_substringmatch(cache, attrib, value): + if value: + for val, elem_set in cache.attrib_map[attrib].iteritems(): + if value in val: + for elem in elem_set: + yield elem + +# }}} + default_dispatch_map = {name.partition('_')[2]:obj for name, obj in globals().items() if name.startswith('select_') and callable(obj)} if __name__ == '__main__': from pprint import pprint root = etree.fromstring('

') select = Select(root, trace=True) - pprint(list(select('p#p.one.two'))) + pprint(list(select('[class~=two]'))) diff --git a/src/css_selectors/tests.py b/src/css_selectors/tests.py index d6dd74f270..bfe871ebca 100644 --- a/src/css_selectors/tests.py +++ b/src/css_selectors/tests.py @@ -8,11 +8,67 @@ __copyright__ = '2015, Kovid Goyal ' import unittest, sys, argparse +from lxml import etree + from css_selectors.errors import SelectorSyntaxError from css_selectors.parse import tokenize, parse +from css_selectors.select import Select class TestCSSSelectors(unittest.TestCase): + # Test data {{{ + HTML_IDS = ''' + + + + +
+ + + + link +
    +
  1. content
  2. +
  3. +
    +
    +
  4. +
  5. +
  6. +
  7. +
  8. +
  9. +
+

+ hi there + guy + + + + + + + +

+ + +
+

+
    +
+ + + + +
+
+ +''' +# }}} + ae = unittest.TestCase.assertEqual def test_tokenizer(self): # {{{ @@ -277,6 +333,48 @@ class TestCSSSelectors(unittest.TestCase): "Got nested :not()") # }}} + def test_select(self): + document = etree.fromstring(self.HTML_IDS) + select = Select(document) + + def select_ids(selector): + for elem in select(selector): + yield elem.get('id') or 'nil' + + def pcss(main, *selectors, **kwargs): + result = list(select_ids(main)) + for selector in selectors: + self.ae(list(select_ids(selector)), result) + return result + all_ids = pcss('*') + self.ae(all_ids[:6], [ + 'html', 'nil', 'link-href', 'link-nohref', 'nil', 'outer-div']) + self.ae(all_ids[-1:], ['foobar-span']) + self.ae(pcss('div'), ['outer-div', 'li-div', 'foobar-div']) + self.ae(pcss('DIV'), [ + 'outer-div', 'li-div', 'foobar-div']) # case-insensitive in HTML + self.ae(pcss('div div'), ['li-div']) + self.ae(pcss('div, div div'), ['outer-div', 'li-div', 'foobar-div']) + self.ae(pcss('a[name]'), ['name-anchor']) + self.ae(pcss('a[NAme]'), ['name-anchor']) # case-insensitive in HTML: + self.ae(pcss('a[rel]'), ['tag-anchor', 'nofollow-anchor']) + self.ae(pcss('a[rel="tag"]'), ['tag-anchor']) + self.ae(pcss('a[href*="localhost"]'), ['tag-anchor']) + self.ae(pcss('a[href*=""]'), []) + self.ae(pcss('a[href^="http"]'), ['tag-anchor', 'nofollow-anchor']) + self.ae(pcss('a[href^="http:"]'), ['tag-anchor']) + self.ae(pcss('a[href^=""]'), []) + self.ae(pcss('a[href$="org"]'), ['nofollow-anchor']) + self.ae(pcss('a[href$=""]'), []) + self.ae(pcss('div[foobar~="bc"]', 'div[foobar~="cde"]'), ['foobar-div']) + self.ae(pcss('[foobar~="ab bc"]', '[foobar~=""]', '[foobar~=" \t"]'), []) + self.ae(pcss('div[foobar~="cd"]'), []) + self.ae(pcss('*[lang|="En"]', '[lang|="En-us"]'), ['second-li']) + # Attribute values are case sensitive + self.ae(pcss('*[lang|="en"]', '[lang|="en-US"]'), []) + self.ae(pcss('*[lang|="e"]'), []) + + # Run tests {{{ def find_tests(): return unittest.defaultTestLoader.loadTestsFromTestCase(TestCSSSelectors)