From 7fe5fff3110cc43e6846ed53d36aecdea0d8f06e Mon Sep 17 00:00:00 2001
From: Kovid Goyal <kovid@kovidgoyal.net>
Date: Mon, 21 Jun 2021 11:48:22 +0530
Subject: [PATCH] Add tests for stemming

---
 src/calibre/db/sqlite_extension.cpp | 45 +++++++++++++++++++++++------
 src/calibre/db/tests/fts.py         | 25 ++++++++++++++++++++++---
 2 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/src/calibre/db/sqlite_extension.cpp b/src/calibre/db/sqlite_extension.cpp
index e84ca150b1..1cff371c35 100644
--- a/src/calibre/db/sqlite_extension.cpp
+++ b/src/calibre/db/sqlite_extension.cpp
@@ -141,6 +141,7 @@ public:
 };
 
 typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
+typedef std::unique_ptr<Stemmer> StemmerPtr;
 
 class Tokenizer {
 private:
@@ -151,7 +152,7 @@ private:
     token_callback_func current_callback;
     void *current_callback_ctx;
     std::map<std::string, BreakIterator> iterators;
-    std::map<std::string, Stemmer> stemmers;
+    std::map<std::string, StemmerPtr> stemmers;
 
     bool is_token_char(UChar32 ch) const {
         switch(u_charType(ch)) {
@@ -173,17 +174,18 @@ private:
         return false;
     }
 
-    int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, Stemmer &stemmer, int flags = 0) {
+    int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, StemmerPtr &stemmer, int flags = 0) {
         token_buf.clear(); token_buf.reserve(4 * token.length());
         token.toUTF8String(token_buf);
         const char *root = token_buf.c_str(); int sz = (int)token_buf.size();
-        if (stem_words && stemmer) {
-            root = stemmer.stem(root, sz, &sz);
+        if (stem_words && stemmer->operator bool()) {
+            root = stemmer->stem(root, sz, &sz);
             if (!root) {
-                root = token_buf.c_str(); sz = (int)token_buf.size();
+                root = token_buf.c_str();
+                sz = (int)token_buf.size();
             }
         }
-        return current_callback(current_callback_ctx, flags, token_buf.c_str(), (int)token_buf.size(), byte_offsets.at(start_offset), byte_offsets.at(end_offset));
+        return current_callback(current_callback_ctx, flags, root, (int)sz, byte_offsets.at(start_offset), byte_offsets.at(end_offset));
     }
 
     const char* iterator_language_for_script(UScriptCode script) const {
@@ -247,14 +249,18 @@ private:
         return ans->second;
     }
 
-    Stemmer& ensure_stemmer(const char *lang = "") {
+    StemmerPtr& ensure_stemmer(const char *lang = "") {
         if (!lang[0]) lang = current_ui_language.c_str();
         auto ans = stemmers.find(lang);
-        if (ans == stemmers.end()) stemmers[lang] = stem_words ? Stemmer(lang) : Stemmer();
+        if (ans == stemmers.end()) {
+            if (stem_words) stemmers[lang] = StemmerPtr(new Stemmer(lang));
+            else stemmers[lang] = StemmerPtr(new Stemmer());
+            ans = stemmers.find(lang);
+        }
         return ans->second;
     }
 
-    int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, Stemmer &stemmer) {
+    int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, StemmerPtr &stemmer) {
         word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
         int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
         int rc = SQLITE_OK;
@@ -493,6 +499,21 @@ tokenize(PyObject *self, PyObject *args) {
     return ans.detach();
 }
 
+static PyObject*
+stem(PyObject *self, PyObject *args) {
+    const char *text, *lang = "en"; int text_length;
+    if (!PyArg_ParseTuple(args, "s#|s", &text, &text_length, &lang)) return NULL;
+    Stemmer s(lang);
+    if (!s) {
+        PyErr_SetString(PyExc_ValueError, "No stemmer for the specified language");
+        return NULL;
+    }
+    int sz;
+    const char* result = s.stem(text, text_length, &sz);
+    if (!result) return PyErr_NoMemory();
+    return Py_BuildValue("s#", result, sz);
+}
+
 static PyMethodDef methods[] = {
     {"get_locales_for_break_iteration", get_locales_for_break_iteration, METH_NOARGS,
      "Get list of available locales for break iteration"
@@ -503,6 +524,12 @@ static PyMethodDef methods[] = {
     {"tokenize", tokenize, METH_VARARGS,
      "Tokenize a string, useful for testing"
     },
+    {"tokenize", tokenize, METH_VARARGS,
+     "Tokenize a string, useful for testing"
+    },
+    {"stem", stem, METH_VARARGS,
+     "Stem a word in the specified language, defaulting to English"
+    },
     {NULL, NULL, 0, NULL}
 };
 
diff --git a/src/calibre/db/tests/fts.py b/src/calibre/db/tests/fts.py
index 63460ca9c4..7f777b7737 100644
--- a/src/calibre/db/tests/fts.py
+++ b/src/calibre/db/tests/fts.py
@@ -19,7 +19,7 @@ def print(*args, **kwargs):
 
 class TestConn(Connection):
 
-    def __init__(self, remove_diacritics=True, language='en'):
+    def __init__(self, remove_diacritics=True, language='en', stem_words=False):
         from calibre_extensions.sqlite_extension import set_ui_language
         set_ui_language(language)
         super().__init__(':memory:')
@@ -27,8 +27,9 @@ class TestConn(Connection):
         options = []
         options.append('remove_diacritics'), options.append('2' if remove_diacritics else '0')
         options = ' '.join(options)
+        tok = 'porter ' if stem_words else ''
         self.execute(f'''
-CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = 'unicode61 {options}');
+CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = '{tok}unicode61 {options}');
 CREATE VIRTUAL TABLE fts_row USING fts5vocab(fts_table, row);
 ''')
 
@@ -139,6 +140,26 @@ class FTSTest(BaseTest):
         self.ae(conn.search("叫"), [("你don't>叫<mess",)])
 
+        conn = TestConn(stem_words=True)
+        conn.insert_text('a simplistic connection')
+        self.ae(conn.search("connection"), [('a simplistic >connection<',),])
+        self.ae(conn.search("connect"), [('a simplistic >connection<',),])
+        self.ae(conn.search("simplistic connect"), [('a >simplistic< >connection<',),])
+        self.ae(conn.search("simplist"), [('a >simplistic< connection',),])
+
+    # }}}
+
     def test_fts_query_syntax(self):  # {{{
         conn = TestConn()
         conn.insert_text('one two three')
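
Note (not part of the patch): a minimal sketch of how the two new pieces are expected to be exercised, assuming calibre_extensions.sqlite_extension has been rebuilt with this change. The stem() binding wraps the extension's Stemmer class, while TestConn(stem_words=True) prepends SQLite's built-in 'porter' FTS5 tokenizer; the outputs shown are illustrative, inferred from the assertions in the test above.

    # Illustrative sketch, not part of the patch. Assumes a build of
    # calibre_extensions.sqlite_extension that includes the stem() binding above.
    from calibre_extensions.sqlite_extension import stem

    # stem(word, lang='en') returns the stemmed form, or raises ValueError
    # when no stemmer exists for the requested language (per the C++ code above).
    print(stem('connection'))        # expected root: 'connect'
    print(stem('simplistic', 'en'))  # expected root: 'simplist'

    # The test helper takes the other path: stem_words=True makes TestConn build
    # the FTS table with tokenize = 'porter unicode61 ...', so a query for the
    # root form matches the inflected text:
    #     conn = TestConn(stem_words=True)
    #     conn.insert_text('a simplistic connection')
    #     conn.search('connect')  # -> [('a simplistic >connection<',)]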