Implement attribute selectors

This commit is contained in:
Kovid Goyal 2015-02-19 21:36:18 +05:30
parent f4dc77b839
commit 730ab1098e
2 changed files with 192 additions and 7 deletions

View File

@ -6,6 +6,7 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import re
from collections import OrderedDict, defaultdict
from functools import wraps
@ -20,6 +21,9 @@ parse_cache = OrderedDict()
XPATH_CACHE_SIZE = 30
xpath_cache = OrderedDict()
# Test that the string is not empty and does not contain whitespace
is_non_whitespace = re.compile(r'^[^ \t\r\n\f]+$').match
def get_parsed_selector(raw):
try:
return parse_cache[raw]
@ -69,7 +73,6 @@ class Select(object):
'^=': 'prefixmatch',
'$=': 'suffixmatch',
'*=': 'substringmatch',
'!=': 'different', # Not in Level 3 but I like it ;)
}
def __init__(self, root, dispatch_map=None, trace=False):
@ -83,12 +86,17 @@ class Select(object):
self._element_map = None
self._id_map = None
self._class_map = None
self._attrib_map = None
self._attrib_space_map = None
def __call__(self, selector):
seen = set()
for selector in get_parsed_selector(selector):
parsed_selector = selector.parsed_tree
for item in self.iterparsedselector(parsed_selector):
yield item
if item not in seen:
yield item
seen.add(item)
def iterparsedselector(self, parsed_selector):
type_name = type(parsed_selector).__name__
@ -108,7 +116,7 @@ class Select(object):
def map_tag_name(x):
return ascii_lower(x.rpartition('}')[2])
for tag in root.iter('*'):
for tag in self.root.iter('*'):
em[map_tag_name(tag.tag)].add(tag)
return self._element_map
@ -131,6 +139,34 @@ class Select(object):
cm[lower(cls)].add(elem)
return self._class_map
@property
def attrib_map(self):
if self._attrib_map is None:
self._attrib_map = am = defaultdict(lambda : defaultdict(OrderedSet))
map_attrib_name = ascii_lower
if '{' in self.root.tag:
def map_attrib_name(x):
return ascii_lower(x.rpartition('}')[2])
for tag in self.root.iter('*'):
for attr, val in tag.attrib.iteritems():
am[map_attrib_name(attr)][val].add(tag)
return self._attrib_map
@property
def attrib_space_map(self):
if self._attrib_space_map is None:
self._attrib_space_map = am = defaultdict(lambda : defaultdict(OrderedSet))
map_attrib_name = ascii_lower
if '{' in self.root.tag:
def map_attrib_name(x):
return ascii_lower(x.rpartition('}')[2])
for tag in self.root.iter('*'):
for attr, val in tag.attrib.iteritems():
for v in val.split():
am[map_attrib_name(attr)][v].add(tag)
return self._attrib_space_map
# Combinators {{{
def select_combinedselector(cache, combined):
@ -189,12 +225,10 @@ def select_element(cache, selector):
def select_hash(cache, selector):
'An id selector'
items = cache.id_map[ascii_lower(selector.id)]
if len(items) > 1:
if len(items) > 0:
for elem in cache.iterparsedselector(selector.selector):
if elem in items:
yield elem
elif items:
yield items[0]
def select_class(cache, selector):
'A class selector'
@ -204,10 +238,63 @@ def select_class(cache, selector):
if elem in items:
yield elem
# Attribute selectors {{{
def select_attrib(cache, selector):
operator = cache.attribute_operator_mapping[selector.operator]
items = frozenset(cache.dispatch_map[operator](cache, ascii_lower(selector.attrib), selector.value))
for item in cache.iterparsedselector(selector.selector):
if item in items:
yield item
def select_exists(cache, attrib, value=None):
for elem_set in cache.attrib_map[attrib].itervalues():
for elem in elem_set:
yield elem
def select_equals(cache, attrib, value):
for elem in cache.attrib_map[attrib][value]:
yield elem
def select_includes(cache, attrib, value):
if is_non_whitespace(value):
for elem in cache.attrib_space_map[attrib][value]:
yield elem
def select_dashmatch(cache, attrib, value):
if value:
for val, elem_set in cache.attrib_map[attrib].iteritems():
if val == value or val.startswith(value + '-'):
for elem in elem_set:
yield elem
def select_prefixmatch(cache, attrib, value):
if value:
for val, elem_set in cache.attrib_map[attrib].iteritems():
if val.startswith(value):
for elem in elem_set:
yield elem
def select_suffixmatch(cache, attrib, value):
if value:
for val, elem_set in cache.attrib_map[attrib].iteritems():
if val.endswith(value):
for elem in elem_set:
yield elem
def select_substringmatch(cache, attrib, value):
if value:
for val, elem_set in cache.attrib_map[attrib].iteritems():
if value in val:
for elem in elem_set:
yield elem
# }}}
default_dispatch_map = {name.partition('_')[2]:obj for name, obj in globals().items() if name.startswith('select_') and callable(obj)}
if __name__ == '__main__':
from pprint import pprint
root = etree.fromstring('<body xmlns="xxx"><p id="p" class="one two"><a id="a"/></p></body>')
select = Select(root, trace=True)
pprint(list(select('p#p.one.two')))
pprint(list(select('[class~=two]')))

View File

@ -8,11 +8,67 @@ __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import unittest, sys, argparse
from lxml import etree
from css_selectors.errors import SelectorSyntaxError
from css_selectors.parse import tokenize, parse
from css_selectors.select import Select
class TestCSSSelectors(unittest.TestCase):
# Test data {{{
HTML_IDS = '''
<html id="html"><head>
<link id="link-href" href="foo" />
<link id="link-nohref" />
</head><body>
<div id="outer-div">
<a id="name-anchor" name="foo"></a>
<a id="tag-anchor" rel="tag" href="http://localhost/foo">link</a>
<a id="nofollow-anchor" rel="nofollow" href="https://example.org">
link</a>
<ol id="first-ol" class="a b c">
<li id="first-li">content</li>
<li id="second-li" lang="En-us">
<div id="li-div">
</div>
</li>
<li id="third-li" class="ab c"></li>
<li id="fourth-li" class="ab
c"></li>
<li id="fifth-li"></li>
<li id="sixth-li"></li>
<li id="seventh-li"> </li>
</ol>
<p id="paragraph">
<b id="p-b">hi</b> <em id="p-em">there</em>
<b id="p-b2">guy</b>
<input type="checkbox" id="checkbox-unchecked" />
<input type="checkbox" id="checkbox-disabled" disabled="" />
<input type="text" id="text-checked" checked="checked" />
<input type="hidden" />
<input type="hidden" disabled="disabled" />
<input type="checkbox" id="checkbox-checked" checked="checked" />
<input type="checkbox" id="checkbox-disabled-checked"
disabled="disabled" checked="checked" />
<fieldset id="fieldset" disabled="disabled">
<input type="checkbox" id="checkbox-fieldset-disabled" />
<input type="hidden" />
</fieldset>
</p>
<ol id="second-ol">
</ol>
<map name="dummymap">
<area shape="circle" coords="200,250,25" href="foo.html" id="area-href" />
<area shape="default" id="area-nohref" />
</map>
</div>
<div id="foobar-div" foobar="ab bc
cde"><span id="foobar-span"></span></div>
</body></html>
'''
# }}}
ae = unittest.TestCase.assertEqual
def test_tokenizer(self): # {{{
@ -277,6 +333,48 @@ class TestCSSSelectors(unittest.TestCase):
"Got nested :not()")
# }}}
def test_select(self):
document = etree.fromstring(self.HTML_IDS)
select = Select(document)
def select_ids(selector):
for elem in select(selector):
yield elem.get('id') or 'nil'
def pcss(main, *selectors, **kwargs):
result = list(select_ids(main))
for selector in selectors:
self.ae(list(select_ids(selector)), result)
return result
all_ids = pcss('*')
self.ae(all_ids[:6], [
'html', 'nil', 'link-href', 'link-nohref', 'nil', 'outer-div'])
self.ae(all_ids[-1:], ['foobar-span'])
self.ae(pcss('div'), ['outer-div', 'li-div', 'foobar-div'])
self.ae(pcss('DIV'), [
'outer-div', 'li-div', 'foobar-div']) # case-insensitive in HTML
self.ae(pcss('div div'), ['li-div'])
self.ae(pcss('div, div div'), ['outer-div', 'li-div', 'foobar-div'])
self.ae(pcss('a[name]'), ['name-anchor'])
self.ae(pcss('a[NAme]'), ['name-anchor']) # case-insensitive in HTML:
self.ae(pcss('a[rel]'), ['tag-anchor', 'nofollow-anchor'])
self.ae(pcss('a[rel="tag"]'), ['tag-anchor'])
self.ae(pcss('a[href*="localhost"]'), ['tag-anchor'])
self.ae(pcss('a[href*=""]'), [])
self.ae(pcss('a[href^="http"]'), ['tag-anchor', 'nofollow-anchor'])
self.ae(pcss('a[href^="http:"]'), ['tag-anchor'])
self.ae(pcss('a[href^=""]'), [])
self.ae(pcss('a[href$="org"]'), ['nofollow-anchor'])
self.ae(pcss('a[href$=""]'), [])
self.ae(pcss('div[foobar~="bc"]', 'div[foobar~="cde"]'), ['foobar-div'])
self.ae(pcss('[foobar~="ab bc"]', '[foobar~=""]', '[foobar~=" \t"]'), [])
self.ae(pcss('div[foobar~="cd"]'), [])
self.ae(pcss('*[lang|="En"]', '[lang|="En-us"]'), ['second-li'])
# Attribute values are case sensitive
self.ae(pcss('*[lang|="en"]', '[lang|="en-US"]'), [])
self.ae(pcss('*[lang|="e"]'), [])
# Run tests {{{
def find_tests():
return unittest.defaultTestLoader.loadTestsFromTestCase(TestCSSSelectors)