mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-07 10:14:46 -04:00
Subsequence matcher: Use primary collation
This commit is contained in:
parent
6e9afc0398
commit
b672f4ed11
@ -171,12 +171,13 @@ static void convert_positions(int32_t *positions, int32_t *final_positions, UCha
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions) {
|
static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions, UStringSearch **searches) {
|
||||||
UChar32 nc, hc, lc;
|
UChar32 hc, lc;
|
||||||
UChar *p;
|
|
||||||
double final_score = 0.0, score = 0.0, score_for_char = 0.0;
|
double final_score = 0.0, score = 0.0, score_for_char = 0.0;
|
||||||
int32_t pos, i, j, hidx, nidx, last_idx, distance, *positions = final_positions + m->needle_len;
|
int32_t pos, i, j, hidx, nidx, last_idx, distance, *positions = final_positions + m->needle_len;
|
||||||
MemoryItem mem = {0};
|
MemoryItem mem = {0};
|
||||||
|
UStringSearch *search = NULL;
|
||||||
|
UErrorCode status = U_ZERO_ERROR;
|
||||||
|
|
||||||
stack_push(stack, 0, 0, 0, 0.0, final_positions);
|
stack_push(stack, 0, 0, 0, 0.0, final_positions);
|
||||||
|
|
||||||
@ -187,11 +188,14 @@ static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions)
|
|||||||
// No memoized result, calculate the score
|
// No memoized result, calculate the score
|
||||||
for (i = nidx; i < m->needle_len;) {
|
for (i = nidx; i < m->needle_len;) {
|
||||||
nidx = i;
|
nidx = i;
|
||||||
U16_NEXT(m->needle, i, m->needle_len, nc); // i now points to next char in needle
|
U16_FWD_1(m->needle, i, m->needle_len);// i now points to next char in needle
|
||||||
if (m->haystack_len - hidx < m->needle_len - nidx) { score = 0.0; break; }
|
search = searches[nidx];
|
||||||
p = u_strchr32(m->haystack + hidx, nc); // TODO: Use primary collation for the find
|
if (search == NULL || m->haystack_len - hidx < m->needle_len - nidx) { score = 0.0; break; }
|
||||||
if (p == NULL) { score = 0.0; break; }
|
status = U_ZERO_ERROR; // We ignore any errors as we already know that hidx is correct
|
||||||
pos = (int32_t)(p - m->haystack);
|
usearch_setOffset(search, hidx, &status);
|
||||||
|
status = U_ZERO_ERROR;
|
||||||
|
pos = usearch_next(search, &status);
|
||||||
|
if (pos == USEARCH_DONE) { score = 0.0; break; } // No matches found
|
||||||
distance = u_countChar32(m->haystack + last_idx, pos - last_idx);
|
distance = u_countChar32(m->haystack + last_idx, pos - last_idx);
|
||||||
if (distance <= 1) score_for_char = m->max_score_per_char;
|
if (distance <= 1) score_for_char = m->max_score_per_char;
|
||||||
else {
|
else {
|
||||||
@ -222,8 +226,30 @@ static double process_item(MatchInfo *m, Stack *stack, int32_t *final_positions)
|
|||||||
return final_score;
|
return final_score;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Open one ICU string search per code point of the needle, each searching the
// given haystack with the supplied collator.
//
// searches: array indexed by UTF-16 offset into needle; the search for the code
//           point starting at offset pos is stored in searches[pos]. Offsets
//           that do not start a code point are left untouched.
// Returns TRUE when every search was opened. On an ICU failure a Python
// ValueError carrying the ICU error name is set, the failing slot is set to
// NULL and FALSE is returned; searches opened before the failure stay in the
// array for the caller to release.
// NOTE(review): match() is invoked between Py_BEGIN_ALLOW_THREADS and
// Py_END_ALLOW_THREADS, so the PyErr_SetString() call below appears to run
// without the GIL held -- confirm and, if so, defer error reporting to the caller.
static bool create_searches(UStringSearch **searches, UChar *haystack, int32_t haystack_len, UChar *needle, int32_t needle_len, UCollator *collator) {
|
||||||
|
int32_t i = 0, pos = 0;
|
||||||
|
UErrorCode status = U_ZERO_ERROR;
|
||||||
|
|
||||||
static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UChar *needle, Match *match_results, int32_t *final_positions, int32_t needle_char_len, UChar *level1, UChar *level2, UChar *level3) {
|
while (i < needle_len) {
|
||||||
|
// pos remembers where the current code point starts; i is advanced past it
pos = i;
|
||||||
|
U16_FWD_1(needle, i, needle_len);
|
||||||
|
// Stop if the index did not advance, otherwise this loop would never end
if (pos == i) break;
|
||||||
|
// The pattern is the single code point needle[pos..i), the text is the whole haystack
searches[pos] = usearch_openFromCollator(needle + pos, i - pos, haystack, haystack_len, collator, NULL, &status);
|
||||||
|
if (U_FAILURE(status)) { PyErr_SetString(PyExc_ValueError, u_errorName(status)); searches[pos] = NULL; return FALSE; }
|
||||||
|
}
|
||||||
|
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close every open search handle stored in searches[0..count) and reset each
// slot to NULL, so the array can be refilled or released afterwards.
// Slots that are already NULL are skipped.
static void free_searches(UStringSearch **searches, int32_t count) {
    int32_t idx;

    for (idx = 0; idx < count; idx++) {
        UStringSearch *handle = searches[idx];

        // Clear the slot unconditionally; only live handles need closing
        searches[idx] = NULL;
        if (handle == NULL) continue;
        usearch_close(handle);
    }
}
|
||||||
|
|
||||||
|
static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UChar *needle, Match *match_results, int32_t *final_positions, int32_t needle_char_len, UCollator *collator, UChar *level1, UChar *level2, UChar *level3) {
|
||||||
Stack stack = {0};
|
Stack stack = {0};
|
||||||
int32_t i = 0, maxhl = 0;
|
int32_t i = 0, maxhl = 0;
|
||||||
int32_t r = 0, *positions = NULL;
|
int32_t r = 0, *positions = NULL;
|
||||||
@ -231,6 +257,7 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
|
|||||||
bool ok = FALSE;
|
bool ok = FALSE;
|
||||||
MemoryItem ***memo = NULL;
|
MemoryItem ***memo = NULL;
|
||||||
int32_t needle_len = u_strlen(needle);
|
int32_t needle_len = u_strlen(needle);
|
||||||
|
UStringSearch **searches = NULL;
|
||||||
|
|
||||||
if (needle_len <= 0 || item_count <= 0) {
|
if (needle_len <= 0 || item_count <= 0) {
|
||||||
for (i = 0; i < (int32_t)item_count; i++) match_results[i].score = 0.0;
|
for (i = 0; i < (int32_t)item_count; i++) match_results[i].score = 0.0;
|
||||||
@ -240,7 +267,8 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
|
|||||||
|
|
||||||
matches = (MatchInfo*)calloc(item_count, sizeof(MatchInfo));
|
matches = (MatchInfo*)calloc(item_count, sizeof(MatchInfo));
|
||||||
positions = (int32_t*)calloc(2*needle_len, sizeof(int32_t)); // One set of positions is the final answer and one set is working space
|
positions = (int32_t*)calloc(2*needle_len, sizeof(int32_t)); // One set of positions is the final answer and one set is working space
|
||||||
if (matches == NULL || positions == NULL) {PyErr_NoMemory(); goto end;}
|
searches = (UStringSearch**) calloc(needle_len, sizeof(UStringSearch*));
|
||||||
|
if (matches == NULL || positions == NULL || searches == NULL) {PyErr_NoMemory(); goto end;}
|
||||||
|
|
||||||
for (i = 0; i < (int32_t)item_count; i++) {
|
for (i = 0; i < (int32_t)item_count; i++) {
|
||||||
matches[i].haystack = items[i];
|
matches[i].haystack = items[i];
|
||||||
@ -270,8 +298,10 @@ static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UCh
|
|||||||
}
|
}
|
||||||
stack_clear(&stack);
|
stack_clear(&stack);
|
||||||
clear_memory(memo, needle_len, matches[i].haystack_len);
|
clear_memory(memo, needle_len, matches[i].haystack_len);
|
||||||
|
free_searches(searches, needle_len);
|
||||||
|
if (!create_searches(searches, matches[i].haystack, matches[i].haystack_len, needle, needle_len, collator)) goto end;
|
||||||
matches[i].memo = memo;
|
matches[i].memo = memo;
|
||||||
match_results[i].score = process_item(&matches[i], &stack, positions);
|
match_results[i].score = process_item(&matches[i], &stack, positions, searches);
|
||||||
convert_positions(positions, final_positions + i, matches[i].haystack, needle_char_len, needle_len, match_results[i].score);
|
convert_positions(positions, final_positions + i, matches[i].haystack, needle_char_len, needle_len, match_results[i].score);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -281,6 +311,7 @@ end:
|
|||||||
nullfree(stack.items);
|
nullfree(stack.items);
|
||||||
nullfree(matches);
|
nullfree(matches);
|
||||||
nullfree(memo);
|
nullfree(memo);
|
||||||
|
if (searches != NULL) { free_searches(searches, needle_len); nullfree(searches); }
|
||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,6 +327,7 @@ typedef struct {
|
|||||||
UChar *level1;
|
UChar *level1;
|
||||||
UChar *level2;
|
UChar *level2;
|
||||||
UChar *level3;
|
UChar *level3;
|
||||||
|
UCollator *collator;
|
||||||
|
|
||||||
} Matcher;
|
} Matcher;
|
||||||
|
|
||||||
@ -308,6 +340,7 @@ static void free_matcher(Matcher *self) {
|
|||||||
}
|
}
|
||||||
nullfree(self->items); nullfree(self->item_lengths);
|
nullfree(self->items); nullfree(self->item_lengths);
|
||||||
nullfree(self->level1); nullfree(self->level2); nullfree(self->level3);
|
nullfree(self->level1); nullfree(self->level2); nullfree(self->level3);
|
||||||
|
if (self->collator != NULL) ucol_close(self->collator); self->collator = NULL;
|
||||||
}
|
}
|
||||||
static void
|
static void
|
||||||
Matcher_dealloc(Matcher* self)
|
Matcher_dealloc(Matcher* self)
|
||||||
@ -320,10 +353,21 @@ Matcher_dealloc(Matcher* self)
|
|||||||
static int
|
static int
|
||||||
Matcher_init(Matcher *self, PyObject *args, PyObject *kwds)
|
Matcher_init(Matcher *self, PyObject *args, PyObject *kwds)
|
||||||
{
|
{
|
||||||
PyObject *items = NULL, *p = NULL, *py_items = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL;
|
PyObject *items = NULL, *p = NULL, *py_items = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL, *collator = NULL;
|
||||||
int32_t i = 0;
|
int32_t i = 0;
|
||||||
|
UErrorCode status = U_ZERO_ERROR;
|
||||||
|
UCollator *col = NULL;
|
||||||
|
|
||||||
|
if (!PyArg_ParseTuple(args, "OOOOO", &items, &collator, &level1, &level2, &level3)) return -1;
|
||||||
|
|
||||||
|
// Clone the passed in collator (cloning is needed as collators are not thread safe)
|
||||||
|
if (!PyCapsule_CheckExact(collator)) { PyErr_SetString(PyExc_TypeError, "Collator must be a capsule"); return -1; }
|
||||||
|
col = (UCollator*)PyCapsule_GetPointer(collator, NULL);
|
||||||
|
if (col == NULL) return -1;
|
||||||
|
self->collator = ucol_safeClone(col, NULL, NULL, &status);
|
||||||
|
col = NULL;
|
||||||
|
if (U_FAILURE(status)) { self->collator = NULL; PyErr_SetString(PyExc_ValueError, u_errorName(status)); return -1; }
|
||||||
|
|
||||||
if (!PyArg_ParseTuple(args, "OOOO", &items, &level1, &level2, &level3)) return -1;
|
|
||||||
py_items = PySequence_Fast(items, "Must pass in two sequence objects");
|
py_items = PySequence_Fast(items, "Must pass in two sequence objects");
|
||||||
if (py_items == NULL) goto end;
|
if (py_items == NULL) goto end;
|
||||||
self->item_count = (uint32_t)PySequence_Size(items);
|
self->item_count = (uint32_t)PySequence_Size(items);
|
||||||
@ -378,7 +422,7 @@ Matcher_calculate_scores(Matcher *self, PyObject *args) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Py_BEGIN_ALLOW_THREADS;
|
Py_BEGIN_ALLOW_THREADS;
|
||||||
ok = match(self->items, self->item_lengths, self->item_count, needle, matches, final_positions, needle_char_len, self->level1, self->level2, self->level3);
|
ok = match(self->items, self->item_lengths, self->item_count, needle, matches, final_positions, needle_char_len, self->collator, self->level1, self->level2, self->level3);
|
||||||
Py_END_ALLOW_THREADS;
|
Py_END_ALLOW_THREADS;
|
||||||
|
|
||||||
if (ok) {
|
if (ok) {
|
||||||
|
@ -6,29 +6,76 @@ from __future__ import (unicode_literals, division, absolute_import,
|
|||||||
__license__ = 'GPL v3'
|
__license__ = 'GPL v3'
|
||||||
__copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>'
|
__copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
from math import ceil
|
||||||
from unicodedata import normalize
|
from unicodedata import normalize
|
||||||
|
from threading import Thread, Lock
|
||||||
|
from Queue import Queue
|
||||||
|
|
||||||
from itertools import izip
|
from itertools import izip
|
||||||
from future_builtins import map
|
from future_builtins import map
|
||||||
|
|
||||||
|
from calibre import detect_ncpus as cpu_count
|
||||||
from calibre.constants import plugins
|
from calibre.constants import plugins
|
||||||
from calibre.utils.icu import primary_sort_key, find
|
from calibre.utils.icu import primary_sort_key, primary_find, primary_collator
|
||||||
|
|
||||||
DEFAULT_LEVEL1 = '/'
|
DEFAULT_LEVEL1 = '/'
|
||||||
DEFAULT_LEVEL2 = '-_ 0123456789'
|
DEFAULT_LEVEL2 = '-_ 0123456789'
|
||||||
DEFAULT_LEVEL3 = '.'
|
DEFAULT_LEVEL3 = '.'
|
||||||
|
|
||||||
|
class Worker(Thread):
    '''
    Daemon thread that services jobs taken from a shared request queue.

    Each job is an argument tuple that is unpacked into
    ``self.process_query``; that method is not defined in this class body and
    must be supplied elsewhere. For every job exactly one 2-tuple is placed on
    the results queue: ``(True, return_value)`` on success, or
    ``(False, formatted_traceback)`` if anything was raised. A ``None`` job is
    the shutdown sentinel and ends the thread.
    '''

    daemon = True

    def __init__(self, requests, results):
        Thread.__init__(self)
        self.requests = requests
        self.results = results
        # At interpreter exit, post the shutdown sentinel to the request queue
        atexit.register(lambda: requests.put(None))

    def run(self):
        while True:
            job = self.requests.get()
            if job is None:
                return
            try:
                self.results.put((True, self.process_query(*job)))
            except:
                # Catch everything so that a result is always posted for the job
                import traceback
                self.results.put((False, traceback.format_exc()))
|
||||||
|
# Lock held by Matcher.__init__ while it checks for and creates the worker pool
wlock = Lock()
|
||||||
|
# Module-wide list of started Worker threads; stays empty until a Matcher is constructed
workers = []
|
||||||
|
|
||||||
|
def split(tasks, pool_size):
    '''
    Partition ``tasks`` into consecutive sub lists of equal size (the last one
    may be shorter), producing no more than ``pool_size`` sub lists for a
    positive ``pool_size``.

    Every sub list holds 2-tuples ``(i, x)`` where ``x`` is an element of
    ``tasks`` and ``i`` is the index of ``x`` in ``tasks``. An empty ``tasks``
    gives an empty list.
    '''
    total = len(tasks)
    chunk = int(ceil(total / pool_size))
    sections = []
    start = 0
    while start < total:
        stop = start + chunk
        sections.append(
            [(start + offset, task) for offset, task in enumerate(tasks[start:stop])]
        )
        start = stop
    return sections
|
||||||
|
|
||||||
|
|
||||||
class Matcher(object):
|
class Matcher(object):
|
||||||
|
|
||||||
def __init__(self, items, level1=DEFAULT_LEVEL1, level2=DEFAULT_LEVEL2, level3=DEFAULT_LEVEL3):
|
def __init__(self, items, level1=DEFAULT_LEVEL1, level2=DEFAULT_LEVEL2, level3=DEFAULT_LEVEL3):
|
||||||
|
with wlock:
|
||||||
|
if not workers:
|
||||||
|
requests, results = Queue(), Queue()
|
||||||
|
w = [Worker(requests, results) for i in range(max(1, cpu_count()))]
|
||||||
|
[x.start() for x in w]
|
||||||
|
workers.extend(w)
|
||||||
items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items))
|
items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items))
|
||||||
items = tuple(map(lambda x: x.encode('utf-8'), items))
|
self.items = items = tuple(items)
|
||||||
sort_keys = tuple(map(primary_sort_key, items))
|
self.sort_keys = tuple(map(primary_sort_key, items))
|
||||||
|
|
||||||
speedup, err = plugins['matcher']
|
|
||||||
if speedup is None:
|
|
||||||
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
|
|
||||||
self.m = speedup.Matcher(items, sort_keys, level1.encode('utf-8'), level2.encode('utf-8'), level3.encode('utf-8'))
|
|
||||||
|
|
||||||
def __call__(self, query):
|
def __call__(self, query):
|
||||||
query = normalize('NFC', unicode(query)).encode('utf-8')
|
query = normalize('NFC', unicode(query)).encode('utf-8')
|
||||||
@ -65,7 +112,7 @@ def process_item(ctx, haystack, needle):
|
|||||||
if (len(haystack) - hidx < len(needle) - i):
|
if (len(haystack) - hidx < len(needle) - i):
|
||||||
score = 0
|
score = 0
|
||||||
break
|
break
|
||||||
pos = find(n, haystack[hidx:])[0] + hidx
|
pos = primary_find(n, haystack[hidx:])[0] + hidx
|
||||||
if pos == -1:
|
if pos == -1:
|
||||||
score = 0
|
score = 0
|
||||||
break
|
break
|
||||||
@ -106,7 +153,7 @@ class CScorer(object):
|
|||||||
speedup, err = plugins['matcher']
|
speedup, err = plugins['matcher']
|
||||||
if speedup is None:
|
if speedup is None:
|
||||||
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
|
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
|
||||||
self.m = speedup.Matcher(items, unicode(level1), unicode(level2), unicode(level3))
|
self.m = speedup.Matcher(items, primary_collator().capsule, unicode(level1), unicode(level2), unicode(level3))
|
||||||
|
|
||||||
def __call__(self, query):
|
def __call__(self, query):
|
||||||
query = normalize('NFC', unicode(query))
|
query = normalize('NFC', unicode(query))
|
||||||
@ -120,7 +167,7 @@ def test():
|
|||||||
c = CScorer(items)
|
c = CScorer(items)
|
||||||
for q in (s, c):
|
for q in (s, c):
|
||||||
print (q)
|
print (q)
|
||||||
for item, (score, positions) in izip(items, q('mno')):
|
for item, (score, positions) in izip(items, q('MNO')):
|
||||||
print (item, score, positions)
|
print (item, score, positions)
|
||||||
|
|
||||||
def test_mem():
|
def test_mem():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user