Use ICU to add full unicode support to the subsequence matcher

This commit is contained in:
Kovid Goyal 2014-03-04 22:34:40 +05:30
parent 27f43a7506
commit 8c67730759
3 changed files with 102 additions and 65 deletions

View File

@ -179,7 +179,10 @@ extensions = [
Extension('matcher', Extension('matcher',
['calibre/gui2/tweak_book/matcher.c'], ['calibre/gui2/tweak_book/matcher.c'],
inc_dirs=(['calibre/utils/chm'] if iswindows else []) # For stdint.h libraries=icu_libs,
lib_dirs=icu_lib_dirs,
cflags=icu_cflags,
inc_dirs=icu_inc_dirs # + (['calibre/utils/chm'] if iswindows else []) # For stdint.h
), ),
Extension('podofo', Extension('podofo',

View File

@ -10,19 +10,20 @@
#include <float.h> #include <float.h>
#include <stdlib.h> #include <stdlib.h>
#include <search.h> #include <search.h>
#include <unicode/uchar.h>
#include <unicode/ustring.h>
#include <unicode/utf16.h>
#ifdef _MSC_VER #ifdef _MSC_VER
#include "stdint.h"
// inline does not work with the visual studio C compiler // inline does not work with the visual studio C compiler
#define inline #define inline
#define qsort qsort_s #define qsort qsort_s
#else #else
#include <stdint.h>
#define qsort qsort_r #define qsort qsort_r
#endif #endif
typedef uint8_t bool; typedef unsigned char bool;
#define TRUE 1 #define TRUE 1
#define FALSE 0 #define FALSE 0
#define MIN(x, y) ((x < y) ? x : y) #define MIN(x, y) ((x < y) ? x : y)
@ -31,26 +32,29 @@ typedef uint8_t bool;
// Algorithm to sort items by subsequence score {{{ // Algorithm to sort items by subsequence score {{{
typedef struct { typedef struct {
char *haystack; UChar *haystack;
uint32_t haystack_len; int32_t haystack_len;
char *needle; UChar *needle;
uint32_t needle_len; int32_t needle_len;
double max_score_per_char; double max_score_per_char;
double *memo; double *memo;
UChar *level1;
UChar *level2;
UChar *level3;
} MatchInfo; } MatchInfo;
typedef struct { typedef struct {
char *item; UChar *item;
char *sort_key; char *sort_key;
uint32_t sort_key_len; uint32_t sort_key_len;
PyObject *py_item; PyObject *py_item;
double score; double score;
} Match; } Match;
static double recursive_match(MatchInfo *m, uint32_t haystack_idx, uint32_t needle_idx, uint32_t last_idx, double score) { static double recursive_match(MatchInfo *m, int32_t haystack_idx, int32_t needle_idx, int32_t last_idx, double score) {
double seen_score = 0.0, memoized = DBL_MAX, score_for_char, factor, sub_score; double seen_score = 0.0, memoized = DBL_MAX, score_for_char, factor, sub_score;
uint32_t i = 0, j = 0, distance; int32_t i = 0, j = 0, distance, curri;
char c, d, last, curr; UChar32 c, d, last;
bool found; bool found;
// do we have a memoized result we can return? // do we have a memoized result we can return?
@ -63,42 +67,34 @@ static double recursive_match(MatchInfo *m, uint32_t haystack_idx, uint32_t need
score = 0.0; score = 0.0;
goto memoize; goto memoize;
} }
for (i = needle_idx; i < m->needle_len; i++) { for (i = needle_idx; i < m->needle_len; ) {
c = m->needle[i]; curri = i;
U16_NEXT(m->needle, i, m->needle_len, c); // i now points to the next codepoint
found = FALSE; found = FALSE;
// similar to above, we'll stop iterating when we know we're too close // similar to above, we'll stop iterating when we know we're too close
// to the end of the string to possibly match // to the end of the string to possibly match
for (j = haystack_idx; for (j = haystack_idx; j <= m->haystack_len - (m->needle_len - curri); ) {
j <= m->haystack_len - (m->needle_len - i); haystack_idx = j;
j++, haystack_idx++) { U16_NEXT(m->haystack, j, m->haystack_len, d); // j now points to the next codepoint
d = m->haystack[j];
if (d >= 'A' && d <= 'Z') {
d += 'a' - 'A'; // add 32 to downcase
}
if (c == d) { if (u_foldCase(c, U_FOLD_CASE_DEFAULT) == u_foldCase(d, U_FOLD_CASE_DEFAULT)) {
found = TRUE; found = TRUE;
// calculate score // calculate score
score_for_char = m->max_score_per_char; score_for_char = m->max_score_per_char;
distance = j - last_idx; distance = haystack_idx - last_idx;
if (distance > 1) { if (distance > 1) {
factor = 1.0; factor = 1.0;
last = m->haystack[j - 1]; U16_GET(m->haystack, haystack_idx - 1, haystack_idx - 1, m->haystack_len, last);
curr = m->haystack[j]; // case matters, so get again if (u_strchr32(m->level1, last))
if (last == '/')
factor = 0.9; factor = 0.9;
else if (last == '-' || else if (u_strchr32(m->level2, last))
last == '_' ||
last == ' ' ||
(last >= '0' && last <= '9'))
factor = 0.8; factor = 0.8;
else if (last >= 'a' && last <= 'z' && else if (u_isULowercase(last) && u_isUUppercase(d))
curr >= 'A' && curr <= 'Z') factor = 0.8; // CamelCase
factor = 0.8; else if (u_strchr32(m->level3, last))
else if (last == '.')
factor = 0.7; factor = 0.7;
else else
// if no "special" chars behind char, factor diminishes // if no "special" chars behind char, factor diminishes
@ -107,7 +103,7 @@ static double recursive_match(MatchInfo *m, uint32_t haystack_idx, uint32_t need
score_for_char *= factor; score_for_char *= factor;
} }
if (++j < m->haystack_len) { if (j < m->haystack_len) {
// bump cursor one char to the right and // bump cursor one char to the right and
// use recursion to try and find a better match // use recursion to try and find a better match
sub_score = recursive_match(m, j, i, last_idx, score); sub_score = recursive_match(m, j, i, last_idx, score);
@ -118,7 +114,7 @@ static double recursive_match(MatchInfo *m, uint32_t haystack_idx, uint32_t need
last_idx = haystack_idx + 1; last_idx = haystack_idx + 1;
break; break;
} }
} } // for(j)
if (!found) { if (!found) {
score = 0.0; score = 0.0;
@ -133,7 +129,7 @@ memoize:
return score; return score;
} }
static bool match(char **items, uint32_t *item_lengths, uint32_t item_count, char *needle, uint32_t needle_len, Match *match_results) { static bool match(UChar **items, int32_t *item_lengths, uint32_t item_count, UChar *needle, int32_t needle_len, Match *match_results, UChar *level1, UChar *level2, UChar *level3) {
uint32_t i = 0, maxhl = 0, n = 0; uint32_t i = 0, maxhl = 0, n = 0;
MatchInfo *matches = NULL; MatchInfo *matches = NULL;
bool ok = FALSE; bool ok = FALSE;
@ -154,6 +150,9 @@ static bool match(char **items, uint32_t *item_lengths, uint32_t item_count, cha
matches[i].needle = needle; matches[i].needle = needle;
matches[i].needle_len = needle_len; matches[i].needle_len = needle_len;
matches[i].max_score_per_char = (1.0 / matches[i].haystack_len + 1.0 / needle_len) / 2.0; matches[i].max_score_per_char = (1.0 / matches[i].haystack_len + 1.0 / needle_len) / 2.0;
matches[i].level1 = level1;
matches[i].level2 = level2;
matches[i].level3 = level3;
maxhl = MAX(maxhl, matches[i].haystack_len); maxhl = MAX(maxhl, matches[i].haystack_len);
} }
maxhl *= needle_len; maxhl *= needle_len;
@ -192,54 +191,78 @@ int cmp_score(const void *a, const void *b, void *arg)
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
// Type-specific fields go here. // Type-specific fields go here.
char **items; UChar **items;
char **sort_items; char **sort_items;
uint32_t item_count; uint32_t item_count;
uint32_t *item_lengths; int32_t *item_lengths;
uint32_t *sort_item_lengths; int32_t *sort_item_lengths;
PyObject *py_items; PyObject *py_items;
PyObject *py_sort_keys; PyObject *py_sort_keys;
UChar *level1;
UChar *level2;
UChar *level3;
} Matcher; } Matcher;
// Matcher.__init__() {{{ // Matcher.__init__() {{{
#define FREE_MATCHER nullfree(self->items); nullfree(self->sort_items); nullfree(self->item_lengths); nullfree(self->sort_item_lengths); Py_XDECREF(self->py_items); Py_XDECREF(self->py_sort_keys); static void free_matcher(Matcher *self) {
uint32_t i = 0;
if (self->items != NULL) {
for (i = 0; i < self->item_count; i++) { nullfree(self->items[i]); }
}
nullfree(self->items); nullfree(self->sort_items); nullfree(self->item_lengths); nullfree(self->sort_item_lengths); Py_XDECREF(self->py_items); Py_XDECREF(self->py_sort_keys);
nullfree(self->level1); nullfree(self->level2); nullfree(self->level3);
}
static void static void
Matcher_dealloc(Matcher* self) Matcher_dealloc(Matcher* self)
{ {
FREE_MATCHER free_matcher(self);
self->ob_type->tp_free((PyObject*)self); self->ob_type->tp_free((PyObject*)self);
} }
#define alloc_uchar(x) (x * 3 + 1)
static int static int
Matcher_init(Matcher *self, PyObject *args, PyObject *kwds) Matcher_init(Matcher *self, PyObject *args, PyObject *kwds)
{ {
PyObject *items = NULL, *sort_keys = NULL, *p = NULL; PyObject *items = NULL, *sort_keys = NULL, *p = NULL;
uint32_t count = 0; char *utf8 = NULL, *level1 = NULL, *level2 = NULL, *level3 = NULL;
Py_ssize_t i = 0; int32_t i = 0;
Py_ssize_t cap = 0, l1s, l2s, l3s;
UErrorCode status = U_ZERO_ERROR;
if (!PyArg_ParseTuple(args, "OO", &items, &sort_keys)) return -1; if (!PyArg_ParseTuple(args, "OOs#s#s#", &items, &sort_keys, &level1, &l1s, &level2, &l2s, &level3, &l3s)) return -1;
self->py_items = PySequence_Fast(items, "Must pass in two sequence objects"); self->py_items = PySequence_Fast(items, "Must pass in two sequence objects");
self->py_sort_keys = PySequence_Fast(sort_keys, "Must pass in two sequence objects"); self->py_sort_keys = PySequence_Fast(sort_keys, "Must pass in two sequence objects");
if (self->py_items == NULL || self->py_sort_keys == NULL) goto end; if (self->py_items == NULL || self->py_sort_keys == NULL) goto end;
count = (uint32_t)PySequence_Size(items); self->item_count = (uint32_t)PySequence_Size(items);
if (count != (uint32_t)PySequence_Size(sort_keys)) { PyErr_SetString(PyExc_TypeError, "The sequences must have the same length."); } if (self->item_count != (uint32_t)PySequence_Size(sort_keys)) { PyErr_SetString(PyExc_TypeError, "The sequences must have the same length."); }
self->items = (char**)calloc(count, sizeof(char*)); self->items = (UChar**)calloc(self->item_count, sizeof(UChar*));
self->sort_items = (char**)calloc(count, sizeof(char*)); self->sort_items = (char**)calloc(self->item_count, sizeof(char*));
self->item_lengths = (uint32_t*)calloc(count, sizeof(uint32_t)); self->item_lengths = (int32_t*)calloc(self->item_count, sizeof(uint32_t));
self->sort_item_lengths = (uint32_t*)calloc(count, sizeof(uint32_t)); self->sort_item_lengths = (int32_t*)calloc(self->item_count, sizeof(uint32_t));
self->item_count = count; self->level1 = (UChar*)calloc(alloc_uchar(l1s), sizeof(UChar));
self->level2 = (UChar*)calloc(alloc_uchar(l2s), sizeof(UChar));
self->level3 = (UChar*)calloc(alloc_uchar(l3s), sizeof(UChar));
if (self->items == NULL || self->sort_items == NULL || self->item_lengths == NULL || self->sort_item_lengths == NULL) { if (self->items == NULL || self->sort_items == NULL || self->item_lengths == NULL || self->sort_item_lengths == NULL || self->level1 == NULL || self->level2 == NULL || self->level3 == NULL) {
PyErr_NoMemory(); goto end; PyErr_NoMemory(); goto end;
} }
u_strFromUTF8Lenient(self->level1, alloc_uchar(l1s), &i, level1, (int32_t)l1s, &status);
u_strFromUTF8Lenient(self->level2, alloc_uchar(l2s), &i, level2, (int32_t)l2s, &status);
u_strFromUTF8Lenient(self->level3, alloc_uchar(l3s), &i, level3, (int32_t)l3s, &status);
if (U_FAILURE(status)) { PyErr_SetString(PyExc_ValueError, "Failed to convert bytes for level string from UTF-8 to UTF-16"); goto end; }
for (i = 0; i < (Py_ssize_t)count; i++) { for (i = 0; i < self->item_count; i++) {
p = PySequence_Fast_GET_ITEM(self->py_items, i); p = PySequence_Fast_GET_ITEM(self->py_items, i);
self->items[i] = PyBytes_AsString(p); utf8 = PyBytes_AsString(p);
if (self->items[i] == NULL) goto end; if (utf8 == NULL) goto end;
self->item_lengths[i] = (uint32_t) PyBytes_GET_SIZE(p); cap = PyBytes_GET_SIZE(p);
self->items[i] = (UChar*)calloc(alloc_uchar(cap), sizeof(UChar));
if (self->items[i] == NULL) { PyErr_NoMemory(); goto end; }
u_strFromUTF8Lenient(self->items[i], alloc_uchar(cap), &(self->item_lengths[i]), utf8, cap, &status);
if (U_FAILURE(status)) { PyErr_SetString(PyExc_ValueError, "Failed to convert bytes from UTF-8 to UTF-16"); goto end; }
p = PySequence_Fast_GET_ITEM(self->py_sort_keys, i); p = PySequence_Fast_GET_ITEM(self->py_sort_keys, i);
self->sort_items[i] = PyBytes_AsString(p); self->sort_items[i] = PyBytes_AsString(p);
if (self->sort_items[i] == NULL) goto end; if (self->sort_items[i] == NULL) goto end;
@ -247,7 +270,7 @@ Matcher_init(Matcher *self, PyObject *args, PyObject *kwds)
} }
end: end:
if (PyErr_Occurred()) { FREE_MATCHER } if (PyErr_Occurred()) { free_matcher(self); }
return (PyErr_Occurred()) ? -1 : 0; return (PyErr_Occurred()) ? -1 : 0;
} }
// Matcher.__init__() }}} // Matcher.__init__() }}}
@ -255,14 +278,21 @@ end:
// Matcher.get_matches {{{ // Matcher.get_matches {{{
static PyObject * static PyObject *
Matcher_get_matches(Matcher *self, PyObject *args) { Matcher_get_matches(Matcher *self, PyObject *args) {
char *needle = NULL; char *cneedle = NULL;
Py_ssize_t qsize = 0; int32_t qsize = 0;
Match *matches = NULL; Match *matches = NULL;
bool ok = FALSE; bool ok = FALSE;
uint32_t i = 0; uint32_t i = 0;
PyObject *items = NULL; PyObject *items = NULL;
UErrorCode status = U_ZERO_ERROR;
UChar *needle = NULL;
if (!PyArg_ParseTuple(args, "s#", &needle, &qsize)) return NULL; if (!PyArg_ParseTuple(args, "s#", &cneedle, &qsize)) return NULL;
needle = (UChar*)calloc(alloc_uchar(qsize), sizeof(UChar));
if (needle == NULL) return PyErr_NoMemory();
u_strFromUTF8Lenient(needle, alloc_uchar(qsize), &qsize, cneedle, qsize, &status);
if (U_FAILURE(status)) { PyErr_SetString(PyExc_ValueError, "Failed to convert bytes from UTF-8 to UTF-16"); goto end; }
items = PyTuple_New(self->item_count); items = PyTuple_New(self->item_count);
matches = (Match*)calloc(self->item_count, sizeof(Match)); matches = (Match*)calloc(self->item_count, sizeof(Match));
@ -275,7 +305,7 @@ Matcher_get_matches(Matcher *self, PyObject *args) {
} }
Py_BEGIN_ALLOW_THREADS; Py_BEGIN_ALLOW_THREADS;
ok = match(self->items, self->item_lengths, self->item_count, needle, (uint32_t)qsize, matches); ok = match(self->items, self->item_lengths, self->item_count, needle, (uint32_t)qsize, matches, self->level1, self->level2, self->level3);
if (ok) qsort(matches, self->item_count, sizeof(Match), cmp_score, NULL); if (ok) qsort(matches, self->item_count, sizeof(Match), cmp_score, NULL);
Py_END_ALLOW_THREADS; Py_END_ALLOW_THREADS;
@ -287,6 +317,7 @@ Matcher_get_matches(Matcher *self, PyObject *args) {
} else { PyErr_NoMemory(); goto end; } } else { PyErr_NoMemory(); goto end; }
end: end:
nullfree(needle);
nullfree(matches); nullfree(matches);
if (PyErr_Occurred()) { Py_XDECREF(items); return NULL; } if (PyErr_Occurred()) { Py_XDECREF(items); return NULL; }
return items; return items;

View File

@ -15,7 +15,7 @@ from calibre.utils.icu import primary_sort_key
class Matcher(object): class Matcher(object):
def __init__(self, items): def __init__(self, items, level1='/', level2='-_ 0123456789', level3='.'):
items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items)) items = map(lambda x: normalize('NFC', unicode(x)), filter(None, items))
items = tuple(map(lambda x: x.encode('utf-8'), items)) items = tuple(map(lambda x: x.encode('utf-8'), items))
sort_keys = tuple(map(primary_sort_key, items)) sort_keys = tuple(map(primary_sort_key, items))
@ -23,11 +23,11 @@ class Matcher(object):
speedup, err = plugins['matcher'] speedup, err = plugins['matcher']
if speedup is None: if speedup is None:
raise RuntimeError('Failed to load the matcher plugin with error: %s' % err) raise RuntimeError('Failed to load the matcher plugin with error: %s' % err)
self.m = speedup.Matcher(items, sort_keys) self.m = speedup.Matcher(items, sort_keys, level1.encode('utf-8'), level2.encode('utf-8'), level3.encode('utf-8'))
def __call__(self, query): def __call__(self, query):
query = normalize('NFC', unicode(query)).encode('utf-8') query = normalize('NFC', unicode(query)).encode('utf-8')
return self.m.get_matches(query) return map(lambda x:x.decode('utf-8'), self.m.get_matches(query))
def test_mem(): def test_mem():
from calibre.utils.mem import gc_histogram, diff_hists from calibre.utils.mem import gc_histogram, diff_hists
@ -45,4 +45,7 @@ def test_mem():
diff_hists(h1, h2) diff_hists(h1, h2)
if __name__ == '__main__': if __name__ == '__main__':
m = Matcher(['image/one.png', 'image/two.gif', 'text/one.html'])
for q in ('one', 'ton', 'imo'):
print (q, '->', tuple(m(q)))
test_mem() test_mem()