Subsequence matcher: Use primary collation

This commit is contained in:
Kovid Goyal 2014-03-08 10:59:19 +05:30
parent 6e9afc0398
commit b672f4ed11
2 changed files with 116 additions and 25 deletions

View File

@ -171,12 +171,13 @@ static void convert_positions(int32_t *positions, int32_t *final_positions, UCha
}
}
static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions) {
UChar32 nc, hc, lc;
UChar *p;
static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions, UStringSearch **searches) {
UChar32 hc, lc;
double final_score = 0.0, score = 0.0, score_for_char = 0.0;
int32_t pos, i, j, hidx, nidx, last_idx, distance, *positions = final_positions + m->needle_len;
MemoryItem mem = {0};
UStringSearch *search = NULL;
UErrorCode status = U_ZERO_ERROR;
stack_push(stack, 0, 0, 0, 0.0, final_positions);
@ -187,11 +188,14 @@ static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions)
// No memoized result, calculate the score
for (i = nidx; i < m->needle_len;) {
nidx = i;
U16_NEXT(m->needle, i, m->needle_len, nc); // i now points to next char in needle
if (m->haystack_len - hidx < m->needle_len - nidx) { score = 0.0; break; }
p = u_strchr32(m->haystack + hidx, nc); // TODO: Use primary collation for the find
if (p == NULL) { score = 0.0; break; }
pos = (int32_t)(p - m->haystack);
U16_FWD_1(m->needle, i, m->needle_len);// i now points to next char in needle
search = searches[nidx];
if (search == NULL || m->haystack_len - hidx < m->needle_len - nidx) { score = 0.0; break; }
status = U_ZERO_ERROR; // We ignore any errors as we already know that hidx is correct
usearch_setOffset(search, hidx, &status);
status = U_ZERO_ERROR;
pos = usearch_next(search, &status);
if (pos == USEARCH_DONE) { score = 0.0; break; } // No matches found
distance = u_countChar32(m->haystack + last_idx, pos - last_idx);
if (distance <= 1) score_for_char = m->max_score_per_char;
else {
@ -222,8 +226,30 @@ static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions)
return final_score;
}
static bool create_searches(UStringSearch **searches, UChar *haystack, int32_t haystack_len, UChar *needle, int32_t needle_len, UCollator *collator) {
int32_t i = 0, pos = 0;
UErrorCode status = U_ZERO_ERROR;
static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UChar *needle, Match *match_results, int32_t *final_positions, int32_t needle_char_len, UChar *level1, UChar *level2, UChar *level3) {
while (i < needle_len) {
pos = i;
U16_FWD_1(needle, i, needle_len);
if (pos == i) break;
searches[pos] = usearch_openFromCollator(needle + pos, i - pos, haystack, haystack_len, collator, NULL, &status);
if (U_FAILURE(status)) { PyErr_SetString(PyExc_ValueError, u_errorName(status)); searches[pos] = NULL; return FALSE; }
}
return TRUE;
}
static void free_searches(UStringSearch **searches, int32_t count) {
int32_t i = 0;
for (i = 0; i < count; i++) {
if (searches[i] != NULL) usearch_close(searches[i]);
searches[i] = NULL;
}
}
static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UChar *needle, Match *match_results, int32_t *final_positions, int32_t needle_char_len, UCollator *collator, UChar *level1, UChar *level2, UChar *level3) {
Stack stack = {0};
int32_t i = 0, maxhl = 0;
int32_t r = 0, *positions = NULL;
@ -231,6 +257,7 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
bool ok = FALSE;
MemoryItem ***memo = NULL;
int32_t needle_len = u_strlen(needle);
UStringSearch **searches = NULL;
if (needle_len <= 0 || item_count <= 0) {
for (i = 0; i < (int32_t)item_count; i++) match_results[i].score = 0.0;
@ -240,7 +267,8 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
matches = (MatchInfo*)calloc(item_count, sizeof(MatchInfo));
positions = (int32_t*)calloc(2*needle_len, sizeof(int32_t)); // One set of positions is the final answer and one set is working space
if (matches == NULL || positions == NULL) {PyErr_NoMemory(); goto end;}
searches = (UStringSearch**) calloc(needle_len, sizeof(UStringSearch*));
if (matches == NULL || positions == NULL || searches == NULL) {PyErr_NoMemory(); goto end;}
for (i = 0; i < (int32_t)item_count; i++) {
matches[i].haystack = items[i];
@ -270,8 +298,10 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
}
stack_clear(&stack);
clear_memory(memo, needle_len, matches[i].haystack_len);
free_searches(searches, needle_len);
if (!create_searches(searches, matches[i].haystack, matches[i].haystack_len, needle, needle_len, collator)) goto end;
matches[i].memo = memo;
match_results[i].score = process_item(&matches[i], &stack, positions);
match_results[i].score = process_item(&matches[i], &stack, positions, searches);
convert_positions(positions, final_positions + i, matches[i].haystack, needle_char_len, needle_len, match_results[i].score);
}
@ -281,6 +311,7 @@ end:
nullfree(stack.items);
nullfree(matches);
nullfree(memo);
if (searches != NULL) { free_searches(searches, needle_len); nullfree(searches); }
return ok;
}
@ -296,6 +327,7 @@ typedef struct {
UChar *level1;
UChar *level2;
UChar *level3;
UCollator *collator;
} Matcher;
@ -308,6 +340,7 @@ static void free_matcher(Matcher *self) {
}
nullfree(self->items); nullfree(self->item_lengths);
nullfree(self->level1); nullfree(self->level2); nullfree(self->level3);
if (self->collator != NULL) ucol_close(self->collator); self->collator = NULL;
}
static void
Matcher_dealloc(Matcher* self)
@ -320,10 +353,21 @@ Matcher_dealloc(Matcher* self)
static int
Matcher_init(Matcher *self, PyObject *args, PyObject *kwds)
{
PyObject *items = NULL, *p = NULL, *py_items = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL;
PyObject *items = NULL, *p = NULL, *py_items = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL, *collator = NULL;
int32_t i = 0;
UErrorCode status = U_ZERO_ERROR;
UCollator *col = NULL;
if (!PyArg_ParseTuple(args, "OOOOO", &items, &collator, &level1, &level2, &level3)) return -1;
// Clone the passed in collator (cloning is needed as collators are not thread safe)
if (!PyCapsule_CheckExact(collator)) { PyErr_SetString(PyExc_TypeError, "Collator must be a capsule"); return -1; }
col = (UCollator*)PyCapsule_GetPointer(collator, NULL);
if (col == NULL) return -1;
self->collator = ucol_safeClone(col, NULL, NULL, &status);
col = NULL;
if (U_FAILURE(status)) { self->collator = NULL; PyErr_SetString(PyExc_ValueError, u_errorName(status)); return -1; }
if (!PyArg_ParseTuple(args, "OOOO", &items, &level1, &level2, &level3)) return -1;
py_items = PySequence_Fast(items, "Must pass in two sequence objects");
if (py_items == NULL) goto end;
self->item_count = (uint32_t)PySequence_Size(items);
@ -378,7 +422,7 @@ Matcher_calculate_scores(Matcher *self, PyObject *args) {
}
Py_BEGIN_ALLOW_THREADS;
ok = match(self->items, self->item_lengths, self->item_count, needle, matches, final_positions, needle_char_len, self->level1, self->level2, self->level3);
ok = match(self->items, self->item_lengths, self->item_count, needle, matches, final_positions, needle_char_len, self->collator, self->level1, self->level2, self->level3);
Py_END_ALLOW_THREADS;
if (ok) {

View File

@ -6,29 +6,76 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3'
__copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>'
import atexit
from math import ceil
from unicodedata import normalize
from threading import Thread, Lock
from Queue import Queue
from itertools import izip
from future_builtins import map
from calibre import detect_ncpus as cpu_count
from calibre.constants import plugins
from calibre.utils.icu import primary_sort_key, find
from calibre.utils.icu import primary_sort_key, primary_find, primary_collator
DEFAULT_LEVEL1 = '/'
DEFAULT_LEVEL2 = '-_ 0123456789'
DEFAULT_LEVEL3 = '.'
class Worker(Thread):
daemon = True
def __init__(self, requests, results):
Thread.__init__(self)
self.requests, self.results = requests, results
atexit.register(lambda : requests.put(None))
def run(self):
while True:
x = self.requests.get()
if x is None:
break
try:
self.results.put((True, self.process_query(*x)))
except:
import traceback
self.results.put((False, traceback.format_exc()))
wlock = Lock()
workers = []
def split(tasks, pool_size):
'''
Split a list into a list of sub lists, with the number of sub lists being
no more than the number of workers this server supports. Each sublist contains
2-tuples of the form (i, x) where x is an element from the original list
and i is the index of the element x in the original list.
'''
ans, count, pos = [], 0, 0
delta = int(ceil(len(tasks)/pool_size))
while count < len(tasks):
section = []
for t in tasks[pos:pos+delta]:
section.append((count, t))
count += 1
ans.append(section)
pos += delta
return ans
class Matcher(object):
def __init__(self, items, level1=DEFAULT_LEVEL1, level2=DEFAULT_LEVEL2, level3=DEFAULT_LEVEL3):
with wlock:
if not workers:
requests, results = Queue(), Queue()
w = [Worker(requests, results) for i in range(max(1, cpu_count()))]
[x.start() for x in w]
workers.extend(w)
items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items))
items = tuple(map(lambda x: x.encode('utf-8'), items))
sort_keys = tuple(map(primary_sort_key, items))
speedup, err = plugins['matcher']
if speedup is None:
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
self.m = speedup.Matcher(items, sort_keys, level1.encode('utf-8'), level2.encode('utf-8'), level3.encode('utf-8'))
self.items = items = tuple(items)
self.sort_keys = tuple(map(primary_sort_key, items))
def __call__(self, query):
query = normalize('NFC', unicode(query)).encode('utf-8')
@ -65,7 +112,7 @@ def process_item(ctx, haystack, needle):
if (len(haystack) - hidx < len(needle) - i):
score = 0
break
pos = find(n, haystack[hidx:])[0] + hidx
pos = primary_find(n, haystack[hidx:])[0] + hidx
if pos == -1:
score = 0
break
@ -106,7 +153,7 @@ class CScorer(object):
speedup, err = plugins['matcher']
if speedup is None:
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
self.m = speedup.Matcher(items, unicode(level1), unicode(level2), unicode(level3))
self.m = speedup.Matcher(items, primary_collator().capsule, unicode(level1), unicode(level2), unicode(level3))
def __call__(self, query):
query = normalize('NFC', unicode(query))
@ -120,7 +167,7 @@ def test():
c = CScorer(items)
for q in (s, c):
print (q)
for item, (score, positions) in izip(items, q('mno')):
for item, (score, positions) in izip(items, q('MNO')):
print (item, score, positions)
def test_mem():