From f5d56958b813bdf828101ce15e205fe37829f5b9 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 20 Jun 2021 14:42:01 +0530 Subject: [PATCH] Start work on stemming for the ICU tokenizer --- setup/extensions.json | 4 +- src/calibre/db/sqlite_extension.cpp | 90 +++++++++++++++++++++++++---- src/calibre/test_build.py | 2 + 3 files changed, 83 insertions(+), 13 deletions(-) diff --git a/setup/extensions.json b/setup/extensions.json index dd16442ec0..4abbde7edd 100644 --- a/setup/extensions.json +++ b/setup/extensions.json @@ -76,8 +76,8 @@ "headers": "calibre/utils/cpp_binding.h", "sources": "calibre/db/sqlite_extension.cpp", "needs_c++11": true, - "libraries": "icudata icui18n icuuc icuio", - "windows_libraries": "icudt icuin icuuc icuio", + "libraries": "icudata icui18n icuuc icuio stemmer", + "windows_libraries": "icudt icuin icuuc icuio libstemmer", "lib_dirs": "!icu_lib_dirs", "inc_dirs": "!icu_inc_dirs !sqlite_inc_dirs" }, diff --git a/src/calibre/db/sqlite_extension.cpp b/src/calibre/db/sqlite_extension.cpp index 0c362ecf50..e84ca150b1 100644 --- a/src/calibre/db/sqlite_extension.cpp +++ b/src/calibre/db/sqlite_extension.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include "../utils/cpp_binding.h" SQLITE_EXTENSION_INIT1 @@ -104,17 +105,53 @@ struct char_cmp { } }; +class Stemmer { +private: + struct sb_stemmer *handle; + +public: + Stemmer() : handle(NULL) {} + Stemmer(const char *lang) { + char buf[32] = {0}; + size_t len = strlen(lang); + for (size_t i = 0; i < sizeof(buf) - 1 && i < len; i++) { + buf[i] = lang[i]; + if ('A' <= buf[i] && buf[i] <= 'Z') buf[i] += 'a' - 'A'; + } + handle = sb_stemmer_new(buf, NULL); + } + + ~Stemmer() { + if (handle) { + sb_stemmer_delete(handle); + handle = NULL; + } + } + + const char* stem(const char *token, size_t token_sz, int *sz) { + const char *ans = NULL; + if (handle) { + ans = reinterpret_cast(sb_stemmer_stem(handle, reinterpret_cast(token), (int)token_sz)); + if (ans) *sz = sb_stemmer_length(handle); + } + return ans; + } + + explicit operator bool() const noexcept { return handle != NULL; } +}; + typedef std::unique_ptr BreakIterator; class Tokenizer { private: - bool remove_diacritics; + bool remove_diacritics, stem_words; std::unique_ptr diacritics_remover; std::vector byte_offsets; std::string token_buf, current_ui_language; token_callback_func current_callback; void *current_callback_ctx; std::map iterators; + std::map stemmers; bool is_token_char(UChar32 ch) const { switch(u_charType(ch)) { @@ -136,9 +173,16 @@ private: return false; } - int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, int flags = 0) { + int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, Stemmer &stemmer, int flags = 0) { token_buf.clear(); token_buf.reserve(4 * token.length()); token.toUTF8String(token_buf); + const char *root = token_buf.c_str(); int sz = (int)token_buf.size(); + if (stem_words && stemmer) { + root = stemmer.stem(root, sz, &sz); + if (!root) { + root = token_buf.c_str(); sz = (int)token_buf.size(); + } + } return current_callback(current_callback_ctx, flags, token_buf.c_str(), (int)token_buf.size(), byte_offsets.at(start_offset), byte_offsets.at(end_offset)); } @@ -203,7 +247,14 @@ private: return ans->second; } - int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator) { + Stemmer& ensure_stemmer(const char *lang = "") { + if (!lang[0]) lang = current_ui_language.c_str(); + auto ans = stemmers.find(lang); + if (ans == stemmers.end()) stemmers[lang] = stem_words ? Stemmer(lang) : Stemmer(); + return ans->second; + } + + int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, Stemmer &stemmer) { word_iterator->setText(str.tempSubStringBetween(block_start, block_limit)); int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos; int rc = SQLITE_OK; @@ -219,12 +270,12 @@ private: if (is_token) { icu::UnicodeString token(str, token_start_pos, token_end_pos - token_start_pos); token.foldCase(U_FOLD_CASE_DEFAULT); - if ((rc = send_token(token, token_start_pos, token_end_pos)) != SQLITE_OK) return rc; + if ((rc = send_token(token, token_start_pos, token_end_pos, stemmer)) != SQLITE_OK) return rc; if (!for_query && remove_diacritics) { icu::UnicodeString tt(token); diacritics_remover->transliterate(tt); if (tt != token) { - if ((rc = send_token(tt, token_start_pos, token_end_pos, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc; + if ((rc = send_token(tt, token_start_pos, token_end_pos, stemmer, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc; } } } @@ -236,8 +287,9 @@ private: public: int constructor_error; - Tokenizer(const char **args, int nargs) : - remove_diacritics(true), diacritics_remover(), + + Tokenizer(const char **args, int nargs, bool stem_words = false) : + remove_diacritics(true), stem_words(stem_words), diacritics_remover(), byte_offsets(), token_buf(), current_ui_language(""), current_callback(NULL), current_callback_ctx(NULL), iterators(), @@ -248,6 +300,11 @@ public: i++; if (i < nargs && strcmp(args[i], "0") == 0) remove_diacritics = false; } + else if (strcmp(args[i], "stem_words") == 0) { + i++; + if (i < nargs && strcmp(args[i], "0") == 0) stem_words = false; + else stem_words = true; + } } if (remove_diacritics) { icu::ErrorCode status; @@ -277,21 +334,23 @@ public: state.language = ""; state.script = USCRIPT_COMMON; int32_t start_script_block_at = offset; auto word_iterator = std::ref(ensure_lang_iterator(state.language)); + auto stemmer = std::ref(ensure_stemmer(state.language)); while (offset < str.length()) { UChar32 ch = str.char32At(offset); if (at_script_boundary(state, ch)) { if (offset > start_script_block_at) { if ((rc = tokenize_script_block( str, start_script_block_at, offset, - for_query, callback, callback_ctx, word_iterator)) != SQLITE_OK) return rc; + for_query, callback, callback_ctx, word_iterator, stemmer)) != SQLITE_OK) return rc; } start_script_block_at = offset; word_iterator = ensure_lang_iterator(state.language); + stemmer = ensure_stemmer(state.language); } offset = str.moveIndex32(offset, 1); } if (offset > start_script_block_at) { - rc = tokenize_script_block(str, start_script_block_at, offset, for_query, callback, callback_ctx, word_iterator); + rc = tokenize_script_block(str, start_script_block_at, offset, for_query, callback, callback_ctx, word_iterator, stemmer); } return rc; } @@ -312,10 +371,10 @@ fts5_api_from_db(sqlite3 *db, fts5_api **ppApi) { } static int -tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { +_tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut, bool stem_words = false) { int rc = SQLITE_OK; try { - Tokenizer *p = new Tokenizer(azArg, nArg); + Tokenizer *p = new Tokenizer(azArg, nArg, stem_words); *ppOut = reinterpret_cast(p); if (p->constructor_error != SQLITE_OK) { rc = p->constructor_error; @@ -329,6 +388,12 @@ tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return rc; } +static int +tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return _tok_create(sqlite3, azArg, nArg, ppOut); } + +static int +tok_create_with_stemming(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return _tok_create(sqlite3, azArg, nArg, ppOut, true); } + static int tok_tokenize(Fts5Tokenizer *tokenizer_ptr, void *callback_ctx, int flags, const char *text, int text_sz, token_callback_func callback) { Tokenizer *p = reinterpret_cast(tokenizer_ptr); @@ -370,6 +435,9 @@ calibre_sqlite_extension_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_ro } fts5_tokenizer tok = {tok_create, tok_delete, tok_tokenize}; fts5api->xCreateTokenizer(fts5api, "unicode61", reinterpret_cast(fts5api), &tok, NULL); + fts5api->xCreateTokenizer(fts5api, "calibre", reinterpret_cast(fts5api), &tok, NULL); + fts5_tokenizer tok2 = {tok_create_with_stemming, tok_delete, tok_tokenize}; + fts5api->xCreateTokenizer(fts5api, "porter", reinterpret_cast(fts5api), &tok2, NULL); return SQLITE_OK; } } diff --git a/src/calibre/test_build.py b/src/calibre/test_build.py index 5c58ff5431..314cdf1275 100644 --- a/src/calibre/test_build.py +++ b/src/calibre/test_build.py @@ -478,6 +478,8 @@ def find_tests(): ans.addTests(find_tests()) from calibre.spell.dictionary import find_tests ans.addTests(find_tests()) + from calibre.db.tests.fts import find_tests + ans.addTests(find_tests()) return ans