Add tests for stemming

Kovid Goyal 2021-06-21 11:48:22 +05:30
parent 2bfc3d1e7f
commit 7fe5fff311
2 changed files with 59 additions and 11 deletions


@@ -141,6 +141,7 @@ public:
};
typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
+typedef std::unique_ptr<Stemmer> StemmerPtr;
class Tokenizer {
private:
@@ -151,7 +152,7 @@ private:
token_callback_func current_callback;
void *current_callback_ctx;
std::map<const char*, BreakIterator, char_cmp> iterators;
-std::map<const char*, Stemmer, char_cmp> stemmers;
+std::map<const char*, StemmerPtr, char_cmp> stemmers;
bool is_token_char(UChar32 ch) const {
switch(u_charType(ch)) {
@@ -173,17 +174,18 @@ private:
return false;
}
-int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, Stemmer &stemmer, int flags = 0) {
+int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, StemmerPtr &stemmer, int flags = 0) {
token_buf.clear(); token_buf.reserve(4 * token.length());
token.toUTF8String(token_buf);
const char *root = token_buf.c_str(); int sz = (int)token_buf.size();
-if (stem_words && stemmer) {
-root = stemmer.stem(root, sz, &sz);
+if (stem_words && stemmer->operator bool()) {
+root = stemmer->stem(root, sz, &sz);
if (!root) {
-root = token_buf.c_str(); sz = (int)token_buf.size();
+root = token_buf.c_str();
+sz = (int)token_buf.size();
}
}
-return current_callback(current_callback_ctx, flags, token_buf.c_str(), (int)token_buf.size(), byte_offsets.at(start_offset), byte_offsets.at(end_offset));
+return current_callback(current_callback_ctx, flags, root, (int)sz, byte_offsets.at(start_offset), byte_offsets.at(end_offset));
}
const char* iterator_language_for_script(UScriptCode script) const {
@@ -247,14 +249,18 @@ private:
return ans->second;
}
-Stemmer& ensure_stemmer(const char *lang = "") {
+StemmerPtr& ensure_stemmer(const char *lang = "") {
if (!lang[0]) lang = current_ui_language.c_str();
auto ans = stemmers.find(lang);
-if (ans == stemmers.end()) stemmers[lang] = stem_words ? Stemmer(lang) : Stemmer();
+if (ans == stemmers.end()) {
+if (stem_words) stemmers[lang] = StemmerPtr(new Stemmer(lang));
+else stemmers[lang] = StemmerPtr(new Stemmer());
+ans = stemmers.find(lang);
+}
return ans->second;
}
-int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, Stemmer &stemmer) {
+int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, StemmerPtr &stemmer) {
word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
int rc = SQLITE_OK;
@@ -493,6 +499,21 @@ tokenize(PyObject *self, PyObject *args) {
return ans.detach();
}
+static PyObject*
+stem(PyObject *self, PyObject *args) {
+const char *text, *lang = "en"; int text_length;
+if (!PyArg_ParseTuple(args, "s#|s", &text, &text_length, &lang)) return NULL;
+Stemmer s(lang);
+if (!s) {
+PyErr_SetString(PyExc_ValueError, "No stemmer for the specified language");
+return NULL;
+}
+int sz;
+const char* result = s.stem(text, text_length, &sz);
+if (!result) return PyErr_NoMemory();
+return Py_BuildValue("s#", result, sz);
+}
static PyMethodDef methods[] = {
{"get_locales_for_break_iteration", get_locales_for_break_iteration, METH_NOARGS,
"Get list of available locales for break iteration"
@@ -503,6 +524,12 @@ static PyMethodDef methods[] = {
{"tokenize", tokenize, METH_VARARGS,
"Tokenize a string, useful for testing"
},
{"tokenize", tokenize, METH_VARARGS,
"Tokenize a string, useful for testing"
},
{"stem", stem, METH_VARARGS,
"Stem a word in the specified language, defaulting to English"
},
{NULL, NULL, 0, NULL}
};
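As a rough illustration (not part of the commit), the new stem() binding exposed above can be exercised from Python much as the test below does; only the module path calibre_extensions.sqlite_extension and the expected results are taken from the test file, the rest is an assumed sketch:

from calibre_extensions.sqlite_extension import stem

# Default language is English; a second argument selects another stemmer language.
stem('connection')          # -> 'connect'
stem('maintenaient', 'fr')  # -> 'mainten'
# Per the wrapper above, a language for which no Stemmer can be constructed
# raises ValueError("No stemmer for the specified language").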


@@ -19,7 +19,7 @@ def print(*args, **kwargs):
class TestConn(Connection):
-def __init__(self, remove_diacritics=True, language='en'):
+def __init__(self, remove_diacritics=True, language='en', stem_words=False):
from calibre_extensions.sqlite_extension import set_ui_language
set_ui_language(language)
super().__init__(':memory:')
@@ -27,8 +27,9 @@ class TestConn(Connection):
options = []
options.append('remove_diacritics'), options.append('2' if remove_diacritics else '0')
options = ' '.join(options)
+tok = 'porter ' if stem_words else ''
self.execute(f'''
-CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = 'unicode61 {options}');
+CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = '{tok}unicode61 {options}');
CREATE VIRTUAL TABLE fts_row USING fts5vocab(fts_table, row);
''')
@@ -139,6 +140,26 @@ class FTSTest(BaseTest):
self.ae(conn.search(""), [("你don't>叫<mess",)])
# }}}
+def test_fts_stemming(self): # {{{
+from calibre_extensions.sqlite_extension import stem
+self.ae(stem('run'), 'run')
+self.ae(stem('connection'), 'connect')
+self.ae(stem('maintenaient'), 'maintenai')
+self.ae(stem('maintenaient', 'fr'), 'mainten')
+self.ae(stem('continué', 'fr'), 'continu')
+self.ae(stem('maître', 'FRA'), 'maîtr')
+conn = TestConn(stem_words=True)
+conn.insert_text('a simplistic connection')
+self.ae(conn.term_row_counts(), {'a': 1, 'connect': 1, 'simplist': 1})
+self.ae(conn.search("connection"), [('a simplistic >connection<',),])
+self.ae(conn.search("connect"), [('a simplistic >connection<',),])
+self.ae(conn.search("simplistic connect"), [('a >simplistic< >connection<',),])
+self.ae(conn.search("simplist"), [('a >simplistic< connection',),])
+# }}}
def test_fts_query_syntax(self): # {{{
conn = TestConn()
conn.insert_text('one two three')
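
For context, a standalone sketch (not from the commit) of what TestConn(stem_words=True) sets up: FTS5's porter wrapper tokenizer applied on top of unicode61. It assumes only that the Python sqlite3 module is linked against an SQLite build with FTS5 enabled:

import sqlite3

db = sqlite3.connect(':memory:')
# Same tokenizer spec the test builds: porter stemming wrapped around unicode61.
db.execute("CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = 'porter unicode61 remove_diacritics 2')")
db.execute('INSERT INTO fts_table(t) VALUES (?)', ('a simplistic connection',))
# Both stemmed and unstemmed query forms match, mirroring test_fts_stemming.
for q in ('connect', 'connection', 'simplist'):
    print(q, db.execute(
        "SELECT highlight(fts_table, 0, '>', '<') FROM fts_table WHERE fts_table MATCH ?", (q,)
    ).fetchall())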