Start work on stemming for the ICU tokenizer

Kovid Goyal 2021-06-20 14:42:01 +05:30
parent 5565c3395e
commit f5d56958b8
3 changed files with 83 additions and 13 deletions


@@ -76,8 +76,8 @@
"headers": "calibre/utils/cpp_binding.h",
"sources": "calibre/db/sqlite_extension.cpp",
"needs_c++11": true,
"libraries": "icudata icui18n icuuc icuio",
"windows_libraries": "icudt icuin icuuc icuio",
"libraries": "icudata icui18n icuuc icuio stemmer",
"windows_libraries": "icudt icuin icuuc icuio libstemmer",
"lib_dirs": "!icu_lib_dirs",
"inc_dirs": "!icu_inc_dirs !sqlite_inc_dirs"
},
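The extra link libraries ("stemmer" on Unix, "libstemmer" on Windows) are the Snowball stemming library whose header is included in the C++ changes below. For orientation, here is a minimal, self-contained sketch (not part of the commit) of the C API the new Stemmer wrapper builds on; the algorithm name "english" is only an example, and passing NULL for the character encoding selects UTF-8.

// Illustrative libstemmer usage: create a stemmer, stem one word, free the handle.
#include <libstemmer.h>
#include <cstdio>
#include <cstring>

int main() {
    struct sb_stemmer *st = sb_stemmer_new("english", NULL);  // NULL charenc means UTF-8
    if (!st) return 1;  // unknown algorithm name or allocation failure
    const char *word = "running";
    const sb_symbol *stemmed = sb_stemmer_stem(
        st, reinterpret_cast<const sb_symbol*>(word), (int)std::strlen(word));
    if (stemmed)  // the result buffer is owned by the stemmer and valid until the next call
        std::printf("%.*s\n", sb_stemmer_length(st), reinterpret_cast<const char*>(stemmed));
    sb_stemmer_delete(st);
    return 0;
}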


@@ -21,6 +21,7 @@
#include <unicode/errorcode.h>
#include <unicode/brkiter.h>
#include <unicode/uscript.h>
#include <libstemmer.h>
#include "../utils/cpp_binding.h"
SQLITE_EXTENSION_INIT1
@@ -104,17 +105,53 @@ struct char_cmp {
}
};
class Stemmer {
private:
struct sb_stemmer *handle;
public:
Stemmer() : handle(NULL) {}
Stemmer(const char *lang) {
char buf[32] = {0};
size_t len = strlen(lang);
for (size_t i = 0; i < sizeof(buf) - 1 && i < len; i++) {
buf[i] = lang[i];
if ('A' <= buf[i] && buf[i] <= 'Z') buf[i] += 'a' - 'A';
}
handle = sb_stemmer_new(buf, NULL);
}
~Stemmer() {
if (handle) {
sb_stemmer_delete(handle);
handle = NULL;
}
}
const char* stem(const char *token, size_t token_sz, int *sz) {
const char *ans = NULL;
if (handle) {
ans = reinterpret_cast<const char*>(sb_stemmer_stem(handle, reinterpret_cast<const sb_symbol*>(token), (int)token_sz));
if (ans) *sz = sb_stemmer_length(handle);
}
return ans;
}
explicit operator bool() const noexcept { return handle != NULL; }
// handle is an owned raw pointer: disable copying and transfer ownership on move,
// so that storing Stemmer values in the std::map below cannot double-free the handle
Stemmer(const Stemmer&) = delete;
Stemmer& operator=(const Stemmer&) = delete;
Stemmer(Stemmer &&other) noexcept : handle(other.handle) { other.handle = NULL; }
Stemmer& operator=(Stemmer &&other) noexcept { if (this != &other) { if (handle) sb_stemmer_delete(handle); handle = other.handle; other.handle = NULL; } return *this; }
};
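Since Stemmer owns a raw sb_stemmer handle, a shorter way to get the same move-only ownership (a sketch of an alternative, not what the commit does) is to hold the handle in a std::unique_ptr with sb_stemmer_delete as its deleter; the hand-written destructor and special members then disappear. The lower-casing of the language name is omitted here for brevity.

// Hypothetical variant: unique_ptr manages the libstemmer handle, so copies are
// rejected at compile time and moves transfer ownership automatically.
#include <libstemmer.h>
#include <memory>
#include <cstddef>

class StemmerAlt {
    std::unique_ptr<struct sb_stemmer, void(*)(struct sb_stemmer*)> handle;
public:
    StemmerAlt() : handle(nullptr, sb_stemmer_delete) {}
    explicit StemmerAlt(const char *lang) : handle(sb_stemmer_new(lang, NULL), sb_stemmer_delete) {}
    const char* stem(const char *token, size_t token_sz, int *sz) {
        if (!handle) return NULL;
        const sb_symbol *ans = sb_stemmer_stem(handle.get(),
            reinterpret_cast<const sb_symbol*>(token), (int)token_sz);
        if (ans) *sz = sb_stemmer_length(handle.get());
        return reinterpret_cast<const char*>(ans);
    }
    explicit operator bool() const noexcept { return bool(handle); }
};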
typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
class Tokenizer {
private:
bool remove_diacritics;
bool remove_diacritics, stem_words;
std::unique_ptr<icu::Transliterator> diacritics_remover;
std::vector<int> byte_offsets;
std::string token_buf, current_ui_language;
token_callback_func current_callback;
void *current_callback_ctx;
std::map<const char*, BreakIterator, char_cmp> iterators;
std::map<const char*, Stemmer, char_cmp> stemmers;
bool is_token_char(UChar32 ch) const {
switch(u_charType(ch)) {
@@ -136,9 +173,16 @@ private:
return false;
}
int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, int flags = 0) {
int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, Stemmer &stemmer, int flags = 0) {
token_buf.clear(); token_buf.reserve(4 * token.length());
token.toUTF8String(token_buf);
const char *root = token_buf.c_str(); int sz = (int)token_buf.size();
if (stem_words && stemmer) {
root = stemmer.stem(root, sz, &sz);
if (!root) {
root = token_buf.c_str(); sz = (int)token_buf.size();
}
}
return current_callback(current_callback_ctx, flags, root, sz, byte_offsets.at(start_offset), byte_offsets.at(end_offset));
}
@@ -203,7 +247,14 @@ private:
return ans->second;
}
int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator) {
Stemmer& ensure_stemmer(const char *lang = "") {
if (!lang[0]) lang = current_ui_language.c_str();
auto ans = stemmers.find(lang);
if (ans == stemmers.end()) {
// look the freshly inserted entry up again rather than dereferencing the end() iterator
stemmers[lang] = stem_words ? Stemmer(lang) : Stemmer();
ans = stemmers.find(lang);
}
return ans->second;
}
int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, Stemmer &stemmer) {
word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
int rc = SQLITE_OK;
@@ -219,12 +270,12 @@ private:
if (is_token) {
icu::UnicodeString token(str, token_start_pos, token_end_pos - token_start_pos);
token.foldCase(U_FOLD_CASE_DEFAULT);
if ((rc = send_token(token, token_start_pos, token_end_pos)) != SQLITE_OK) return rc;
if ((rc = send_token(token, token_start_pos, token_end_pos, stemmer)) != SQLITE_OK) return rc;
if (!for_query && remove_diacritics) {
icu::UnicodeString tt(token);
diacritics_remover->transliterate(tt);
if (tt != token) {
if ((rc = send_token(tt, token_start_pos, token_end_pos, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc;
if ((rc = send_token(tt, token_start_pos, token_end_pos, stemmer, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc;
}
}
}
@@ -236,8 +287,9 @@ private:
public:
int constructor_error;
Tokenizer(const char **args, int nargs) :
remove_diacritics(true), diacritics_remover(),
Tokenizer(const char **args, int nargs, bool stem_words = false) :
remove_diacritics(true), stem_words(stem_words), diacritics_remover(),
byte_offsets(), token_buf(), current_ui_language(""),
current_callback(NULL), current_callback_ctx(NULL), iterators(),
@@ -248,6 +300,11 @@ public:
i++;
if (i < nargs && strcmp(args[i], "0") == 0) remove_diacritics = false;
}
else if (strcmp(args[i], "stem_words") == 0) {
i++;
// use this-> so the option updates the member, not the constructor parameter of the same name
if (i < nargs && strcmp(args[i], "0") == 0) this->stem_words = false;
else this->stem_words = true;
}
}
if (remove_diacritics) {
icu::ErrorCode status;
@@ -277,21 +334,23 @@ public:
state.language = ""; state.script = USCRIPT_COMMON;
int32_t start_script_block_at = offset;
auto word_iterator = std::ref(ensure_lang_iterator(state.language));
auto stemmer = std::ref(ensure_stemmer(state.language));
while (offset < str.length()) {
UChar32 ch = str.char32At(offset);
if (at_script_boundary(state, ch)) {
if (offset > start_script_block_at) {
if ((rc = tokenize_script_block(
str, start_script_block_at, offset,
for_query, callback, callback_ctx, word_iterator)) != SQLITE_OK) return rc;
for_query, callback, callback_ctx, word_iterator, stemmer)) != SQLITE_OK) return rc;
}
start_script_block_at = offset;
word_iterator = ensure_lang_iterator(state.language);
stemmer = ensure_stemmer(state.language);
}
offset = str.moveIndex32(offset, 1);
}
if (offset > start_script_block_at) {
rc = tokenize_script_block(str, start_script_block_at, offset, for_query, callback, callback_ctx, word_iterator);
rc = tokenize_script_block(str, start_script_block_at, offset, for_query, callback, callback_ctx, word_iterator, stemmer);
}
return rc;
}
@@ -312,10 +371,10 @@ fts5_api_from_db(sqlite3 *db, fts5_api **ppApi) {
}
static int
tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) {
_tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut, bool stem_words = false) {
int rc = SQLITE_OK;
try {
Tokenizer *p = new Tokenizer(azArg, nArg);
Tokenizer *p = new Tokenizer(azArg, nArg, stem_words);
*ppOut = reinterpret_cast<Fts5Tokenizer *>(p);
if (p->constructor_error != SQLITE_OK) {
rc = p->constructor_error;
@@ -329,6 +388,12 @@ tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) {
return rc;
}
static int
tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return _tok_create(sqlite3, azArg, nArg, ppOut); }
static int
tok_create_with_stemming(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { return _tok_create(sqlite3, azArg, nArg, ppOut, true); }
static int
tok_tokenize(Fts5Tokenizer *tokenizer_ptr, void *callback_ctx, int flags, const char *text, int text_sz, token_callback_func callback) {
Tokenizer *p = reinterpret_cast<Tokenizer*>(tokenizer_ptr);
@@ -370,6 +435,9 @@ calibre_sqlite_extension_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_ro
}
fts5_tokenizer tok = {tok_create, tok_delete, tok_tokenize};
fts5api->xCreateTokenizer(fts5api, "unicode61", reinterpret_cast<void *>(fts5api), &tok, NULL);
fts5api->xCreateTokenizer(fts5api, "calibre", reinterpret_cast<void *>(fts5api), &tok, NULL);
fts5_tokenizer tok2 = {tok_create_with_stemming, tok_delete, tok_tokenize};
fts5api->xCreateTokenizer(fts5api, "porter", reinterpret_cast<void *>(fts5api), &tok2, NULL);
return SQLITE_OK;
}
}
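For context on how the three registrations differ: "unicode61" and "calibre" keep the non-stemming behaviour, while "porter" now maps to the same ICU tokenizer with stem_words turned on. A usage sketch (assuming a sqlite3 connection on which calibre_sqlite_extension_init has already been run; the table, column and query text are made up for illustration):

// Illustrative only: build an FTS5 table with the stemming tokenizer registered
// above and run a MATCH query through the standard SQLite C API.
#include <sqlite3.h>
#include <cstdio>

static int print_row(void *, int ncols, char **vals, char **) {
    for (int i = 0; i < ncols; i++) std::printf("%s\t", vals[i] ? vals[i] : "NULL");
    std::printf("\n");
    return 0;
}

int demo_stemmed_fts(sqlite3 *db) {  // db must already have the extension initialized
    char *errmsg = NULL;
    int rc = sqlite3_exec(db,
        // 'porter' here resolves to tok_create_with_stemming from this commit
        "CREATE VIRTUAL TABLE docs USING fts5(body, tokenize = 'porter');"
        "INSERT INTO docs(body) VALUES ('the cats were running');"
        "SELECT body FROM docs WHERE docs MATCH 'run';",
        print_row, NULL, &errmsg);
    if (rc != SQLITE_OK && errmsg) {
        std::fprintf(stderr, "sqlite error: %s\n", errmsg);
        sqlite3_free(errmsg);
    }
    return rc;
}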


@@ -478,6 +478,8 @@ def find_tests():
ans.addTests(find_tests())
from calibre.spell.dictionary import find_tests
ans.addTests(find_tests())
from calibre.db.tests.fts import find_tests
ans.addTests(find_tests())
return ans