Subsequence matcher: Use primary collation

This commit is contained in:
Kovid Goyal 2014-03-08 10:59:19 +05:30
parent 6e9afc0398
commit b672f4ed11
2 changed files with 116 additions and 25 deletions

View File

@ -171,12 +171,13 @@ static void convert_positions(int32_t *positions, int32_t *final_positions, UCha
} }
} }
static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions) { static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions, UStringSearch **searches) {
UChar32 nc, hc, lc; UChar32 hc, lc;
UChar *p;
double final_score = 0.0, score = 0.0, score_for_char = 0.0; double final_score = 0.0, score = 0.0, score_for_char = 0.0;
int32_t pos, i, j, hidx, nidx, last_idx, distance, *positions = final_positions + m->needle_len; int32_t pos, i, j, hidx, nidx, last_idx, distance, *positions = final_positions + m->needle_len;
MemoryItem mem = {0}; MemoryItem mem = {0};
UStringSearch *search = NULL;
UErrorCode status = U_ZERO_ERROR;
stack_push(stack, 0, 0, 0, 0.0, final_positions); stack_push(stack, 0, 0, 0, 0.0, final_positions);
@ -187,11 +188,14 @@ static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions)
// No memoized result, calculate the score // No memoized result, calculate the score
for (i = nidx; i < m->needle_len;) { for (i = nidx; i < m->needle_len;) {
nidx = i; nidx = i;
U16_NEXT(m->needle, i, m->needle_len, nc); // i now points to next char in needle U16_FWD_1(m->needle, i, m->needle_len);// i now points to next char in needle
if (m->haystack_len - hidx < m->needle_len - nidx) { score = 0.0; break; } search = searches[nidx];
p = u_strchr32(m->haystack + hidx, nc); // TODO: Use primary collation for the find if (search == NULL || m->haystack_len - hidx < m->needle_len - nidx) { score = 0.0; break; }
if (p == NULL) { score = 0.0; break; } status = U_ZERO_ERROR; // We ignore any errors as we already know that hidx is correct
pos = (int32_t)(p - m->haystack); usearch_setOffset(search, hidx, &status);
status = U_ZERO_ERROR;
pos = usearch_next(search, &status);
if (pos == USEARCH_DONE) { score = 0.0; break; } // No matches found
distance = u_countChar32(m->haystack + last_idx, pos - last_idx); distance = u_countChar32(m->haystack + last_idx, pos - last_idx);
if (distance <= 1) score_for_char = m->max_score_per_char; if (distance <= 1) score_for_char = m->max_score_per_char;
else { else {
@ -222,8 +226,30 @@ static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions)
return final_score; return final_score;
} }
// Open one ICU string search per needle character against the given haystack.
//
// The needle is walked one code point at a time (U16_FWD_1). For the code
// point that starts at UTF-16 offset `pos`, a UStringSearch looking for just
// that code point in `haystack` is opened with `collator` and stored in
// searches[pos]. Slots whose offset does not start a code point are not
// written by this function.
//
// Returns TRUE on success. On an ICU failure a Python ValueError carrying the
// ICU error name is set, the failing slot is set to NULL and FALSE is
// returned; searches opened before the failure are left in the array, which
// the caller releases with free_searches().
static bool create_searches(UStringSearch **searches, UChar *haystack, int32_t haystack_len, UChar *needle, int32_t needle_len, UCollator *collator) {
    int32_t i = 0, pos = 0;
    UErrorCode status = U_ZERO_ERROR;
    PyGILState_STATE gil;

    while (i < needle_len) {
        pos = i;
        U16_FWD_1(needle, i, needle_len);
        if (pos == i) break;  // No forward progress: stop instead of looping forever
        searches[pos] = usearch_openFromCollator(needle + pos, i - pos, haystack, haystack_len, collator, NULL, &status);
        if (U_FAILURE(status)) {
            // match(), which calls this function, is run by
            // Matcher_calculate_scores between Py_BEGIN_ALLOW_THREADS and
            // Py_END_ALLOW_THREADS, so the GIL must be re-acquired before the
            // Python error indicator is touched.
            gil = PyGILState_Ensure();
            PyErr_SetString(PyExc_ValueError, u_errorName(status));
            PyGILState_Release(gil);
            searches[pos] = NULL;
            return FALSE;
        }
    }
    return TRUE;
}
// Close every open search in the first `count` slots of `searches` and reset
// each of those slots to NULL, so the array can be reused or released.
static void free_searches(UStringSearch **searches, int32_t count) {
    int32_t slot = 0;

    while (slot < count) {
        UStringSearch *handle = searches[slot];
        if (handle != NULL) usearch_close(handle);
        searches[slot] = NULL;
        slot++;
    }
}
static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UChar *needle, Match *match_results, int32_t *final_positions, int32_t needle_char_len, UCollator *collator, UChar *level1, UChar *level2, UChar *level3) {
Stack stack = {0}; Stack stack = {0};
int32_t i = 0, maxhl = 0; int32_t i = 0, maxhl = 0;
int32_t r = 0, *positions = NULL; int32_t r = 0, *positions = NULL;
@ -231,6 +257,7 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
bool ok = FALSE; bool ok = FALSE;
MemoryItem ***memo = NULL; MemoryItem ***memo = NULL;
int32_t needle_len = u_strlen(needle); int32_t needle_len = u_strlen(needle);
UStringSearch **searches = NULL;
if (needle_len <= 0 || item_count <= 0) { if (needle_len <= 0 || item_count <= 0) {
for (i = 0; i < (int32_t)item_count; i++) match_results[i].score = 0.0; for (i = 0; i < (int32_t)item_count; i++) match_results[i].score = 0.0;
@ -240,7 +267,8 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
matches = (MatchInfo*)calloc(item_count, sizeof(MatchInfo)); matches = (MatchInfo*)calloc(item_count, sizeof(MatchInfo));
positions = (int32_t*)calloc(2*needle_len, sizeof(int32_t)); // One set of positions is the final answer and one set is working space positions = (int32_t*)calloc(2*needle_len, sizeof(int32_t)); // One set of positions is the final answer and one set is working space
if (matches == NULL || positions == NULL) {PyErr_NoMemory(); goto end;} searches = (UStringSearch**) calloc(needle_len, sizeof(UStringSearch*));
if (matches == NULL || positions == NULL || searches == NULL) {PyErr_NoMemory(); goto end;}
for (i = 0; i < (int32_t)item_count; i++) { for (i = 0; i < (int32_t)item_count; i++) {
matches[i].haystack = items[i]; matches[i].haystack = items[i];
@ -270,8 +298,10 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
} }
stack_clear(&stack); stack_clear(&stack);
clear_memory(memo, needle_len, matches[i].haystack_len); clear_memory(memo, needle_len, matches[i].haystack_len);
free_searches(searches, needle_len);
if (!create_searches(searches, matches[i].haystack, matches[i].haystack_len, needle, needle_len, collator)) goto end;
matches[i].memo = memo; matches[i].memo = memo;
match_results[i].score = process_item(&matches[i], &stack, positions); match_results[i].score = process_item(&matches[i], &stack, positions, searches);
convert_positions(positions, final_positions + i, matches[i].haystack, needle_char_len, needle_len, match_results[i].score); convert_positions(positions, final_positions + i, matches[i].haystack, needle_char_len, needle_len, match_results[i].score);
} }
@ -281,6 +311,7 @@ end:
nullfree(stack.items); nullfree(stack.items);
nullfree(matches); nullfree(matches);
nullfree(memo); nullfree(memo);
if (searches != NULL) { free_searches(searches, needle_len); nullfree(searches); }
return ok; return ok;
} }
@ -296,6 +327,7 @@ typedef struct {
UChar *level1; UChar *level1;
UChar *level2; UChar *level2;
UChar *level3; UChar *level3;
UCollator *collator;
} Matcher; } Matcher;
@ -308,6 +340,7 @@ static void free_matcher(Matcher *self) {
} }
nullfree(self->items); nullfree(self->item_lengths); nullfree(self->items); nullfree(self->item_lengths);
nullfree(self->level1); nullfree(self->level2); nullfree(self->level3); nullfree(self->level1); nullfree(self->level2); nullfree(self->level3);
if (self->collator != NULL) ucol_close(self->collator); self->collator = NULL;
} }
static void static void
Matcher_dealloc(Matcher* self) Matcher_dealloc(Matcher* self)
@ -320,10 +353,21 @@ Matcher_dealloc(Matcher* self)
static int static int
Matcher_init(Matcher *self, PyObject *args, PyObject *kwds) Matcher_init(Matcher *self, PyObject *args, PyObject *kwds)
{ {
PyObject *items = NULL, *p = NULL, *py_items = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL; PyObject *items = NULL, *p = NULL, *py_items = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL, *collator = NULL;
int32_t i = 0; int32_t i = 0;
UErrorCode status = U_ZERO_ERROR;
UCollator *col = NULL;
if (!PyArg_ParseTuple(args, "OOOOO", &items, &collator, &level1, &level2, &level3)) return -1;
// Clone the passed in collator (cloning is needed as collators are not thread safe)
if (!PyCapsule_CheckExact(collator)) { PyErr_SetString(PyExc_TypeError, "Collator must be a capsule"); return -1; }
col = (UCollator*)PyCapsule_GetPointer(collator, NULL);
if (col == NULL) return -1;
self->collator = ucol_safeClone(col, NULL, NULL, &status);
col = NULL;
if (U_FAILURE(status)) { self->collator = NULL; PyErr_SetString(PyExc_ValueError, u_errorName(status)); return -1; }
if (!PyArg_ParseTuple(args, "OOOO", &items, &level1, &level2, &level3)) return -1;
py_items = PySequence_Fast(items, "Must pass in two sequence objects"); py_items = PySequence_Fast(items, "Must pass in two sequence objects");
if (py_items == NULL) goto end; if (py_items == NULL) goto end;
self->item_count = (uint32_t)PySequence_Size(items); self->item_count = (uint32_t)PySequence_Size(items);
@ -378,7 +422,7 @@ Matcher_calculate_scores(Matcher *self, PyObject *args) {
} }
Py_BEGIN_ALLOW_THREADS; Py_BEGIN_ALLOW_THREADS;
ok = match(self->items, self->item_lengths, self->item_count, needle, matches, final_positions, needle_char_len, self->level1, self->level2, self->level3); ok = match(self->items, self->item_lengths, self->item_count, needle, matches, final_positions, needle_char_len, self->collator, self->level1, self->level2, self->level3);
Py_END_ALLOW_THREADS; Py_END_ALLOW_THREADS;
if (ok) { if (ok) {

View File

@ -6,29 +6,76 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3' __license__ = 'GPL v3'
__copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>'
import atexit
from math import ceil
from unicodedata import normalize from unicodedata import normalize
from threading import Thread, Lock
from Queue import Queue
from itertools import izip from itertools import izip
from future_builtins import map from future_builtins import map
from calibre import detect_ncpus as cpu_count
from calibre.constants import plugins from calibre.constants import plugins
from calibre.utils.icu import primary_sort_key, find from calibre.utils.icu import primary_sort_key, primary_find, primary_collator
DEFAULT_LEVEL1 = '/' DEFAULT_LEVEL1 = '/'
DEFAULT_LEVEL2 = '-_ 0123456789' DEFAULT_LEVEL2 = '-_ 0123456789'
DEFAULT_LEVEL3 = '.' DEFAULT_LEVEL3 = '.'
class Worker(Thread):
    '''
    Thread that services requests taken from a shared queue.

    Every request is unpacked as the positional arguments of
    ``self.process_query``. The outcome is put on the results queue as
    ``(True, return_value)`` on success, or ``(False, traceback_text)`` when
    anything is raised. A request of ``None`` ends the thread.

    NOTE(review): ``process_query`` is not defined in this class as shown
    here; presumably it is supplied elsewhere -- confirm.
    '''

    daemon = True

    def __init__(self, requests, results):
        Thread.__init__(self)
        self.requests = requests
        self.results = results
        # Queue a stop request (None) when the interpreter exits
        atexit.register(lambda: requests.put(None))

    def run(self):
        while True:
            query = self.requests.get()
            if query is None:
                return
            try:
                answer = self.process_query(*query)
                self.results.put((True, answer))
            except:
                # Report the failure to the consumer as text instead of
                # letting the thread die
                import traceback
                self.results.put((False, traceback.format_exc()))
# Held by Matcher.__init__ while it checks for, and if needed creates, the worker threads
wlock = Lock()
# Worker threads started so far; filled in by the first Matcher.__init__ that finds it empty
workers = []
def split(tasks, pool_size):
    '''
    Split a list into a list of sub lists, with the number of sub lists being
    no more than the number of workers this server supports. Each sublist contains
    2-tuples of the form (i, x) where x is an element from the original list
    and i is the index of the element x in the original list.

    :param tasks: A sequence supporting ``len()`` and slicing.
    :param pool_size: The maximum number of sub lists to produce. Values
        smaller than 1 are treated as 1, so that a zero or negative pool size
        can neither raise ZeroDivisionError nor stall the chunking.
    :return: A list of lists of ``(index, item)`` tuples. Empty if ``tasks``
        is empty.
    '''
    total = len(tasks)
    if total == 0:
        return []
    pool_size = max(1, pool_size)
    # Ceiling division without relying on true division being enabled
    delta = int(-(-total // pool_size))
    return [
        list(enumerate(tasks[pos:pos + delta], pos))
        for pos in range(0, total, delta)
    ]
class Matcher(object): class Matcher(object):
def __init__(self, items, level1=DEFAULT_LEVEL1, level2=DEFAULT_LEVEL2, level3=DEFAULT_LEVEL3): def __init__(self, items, level1=DEFAULT_LEVEL1, level2=DEFAULT_LEVEL2, level3=DEFAULT_LEVEL3):
with wlock:
if not workers:
requests, results = Queue(), Queue()
w = [Worker(requests, results) for i in range(max(1, cpu_count()))]
[x.start() for x in w]
workers.extend(w)
items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items)) items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items))
items = tuple(map(lambda x: x.encode('utf-8'), items)) self.items = items = tuple(items)
sort_keys = tuple(map(primary_sort_key, items)) self.sort_keys = tuple(map(primary_sort_key, items))
speedup, err = plugins['matcher']
if speedup is None:
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
self.m = speedup.Matcher(items, sort_keys, level1.encode('utf-8'), level2.encode('utf-8'), level3.encode('utf-8'))
def __call__(self, query): def __call__(self, query):
query = normalize('NFC', unicode(query)).encode('utf-8') query = normalize('NFC', unicode(query)).encode('utf-8')
@ -65,7 +112,7 @@ def process_item(ctx, haystack, needle):
if (len(haystack) - hidx < len(needle) - i): if (len(haystack) - hidx < len(needle) - i):
score = 0 score = 0
break break
pos = find(n, haystack[hidx:])[0] + hidx pos = primary_find(n, haystack[hidx:])[0] + hidx
if pos == -1: if pos == -1:
score = 0 score = 0
break break
@ -106,7 +153,7 @@ class CScorer(object):
speedup, err = plugins['matcher'] speedup, err = plugins['matcher']
if speedup is None: if speedup is None:
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err) raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
self.m = speedup.Matcher(items, unicode(level1), unicode(level2), unicode(level3)) self.m = speedup.Matcher(items, primary_collator().capsule, unicode(level1), unicode(level2), unicode(level3))
def __call__(self, query): def __call__(self, query):
query = normalize('NFC', unicode(query)) query = normalize('NFC', unicode(query))
@ -120,7 +167,7 @@ def test():
c = CScorer(items) c = CScorer(items)
for q in (s, c): for q in (s, c):
print (q) print (q)
for item, (score, positions) in izip(items, q('mno')): for item, (score, positions) in izip(items, q('MNO')):
print (item, score, positions) print (item, score, positions)
def test_mem(): def test_mem():