diff --git a/src/calibre/db/sqlite_extension.cpp b/src/calibre/db/sqlite_extension.cpp index e31ae5e843..4a66344849 100644 --- a/src/calibre/db/sqlite_extension.cpp +++ b/src/calibre/db/sqlite_extension.cpp @@ -14,6 +14,8 @@ #include #include #include +#include +#include SQLITE_EXTENSION_INIT1 typedef int (*token_callback_func)(void *, int, const char *, int, int, int); @@ -80,11 +82,11 @@ populate_icu_string(const char *text, int text_sz, icu::UnicodeString &str, std: class Tokenizer { private: - bool remove_diacritics; + icu::Transliterator *diacritics_remover; std::vector byte_offsets; + std::string token_buf; token_callback_func current_callback; void *current_callback_ctx; - std::string token_buf; bool is_token_char(UChar32 ch) const { switch(u_charType(ch)) { @@ -103,18 +105,38 @@ private: } } - int send_token(int32_t start_offset, int32_t end_offset, int flags = 0) { + int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, int flags = 0) { + token_buf.clear(); token_buf.reserve(4 * token.length()); + token.toUTF8String(token_buf); return current_callback(current_callback_ctx, flags, token_buf.c_str(), token_buf.size(), byte_offsets[start_offset], byte_offsets[end_offset]); } public: - Tokenizer(const char **args, int nargs) : remove_diacritics(true), byte_offsets(), token_buf() { + int constructor_error; + Tokenizer(const char **args, int nargs) : + diacritics_remover(NULL), + byte_offsets(), token_buf(), + current_callback(NULL), current_callback_ctx(NULL), constructor_error(SQLITE_OK) + { + bool remove_diacritics = true; for (int i = 0; i < nargs; i++) { if (strcmp(args[i], "remove_diacritics") == 0) { i++; if (i < nargs && strcmp(args[i], "0") == 0) remove_diacritics = false; } } + if (remove_diacritics) { + icu::ErrorCode status; + diacritics_remover = icu::Transliterator::createInstance("NFD; [:M:] Remove; NFC", UTRANS_FORWARD, status); + if (status.isFailure()) { + fprintf(stderr, "Failed to create ICU transliterator to remove diacritics with error: %s\n", status.errorName()); + constructor_error = SQLITE_INTERNAL; + } + } + } + ~Tokenizer() { + if (diacritics_remover) icu::Transliterator::unregister(diacritics_remover->getID()); + diacritics_remover = NULL; } int tokenize(void *callback_ctx, int flags, const char *text, int text_sz, token_callback_func callback) { @@ -125,6 +147,8 @@ public: populate_icu_string(text, text_sz, str, byte_offsets); str.foldCase(U_FOLD_CASE_DEFAULT); int32_t offset = str.getChar32Start(0); + int rc; + bool for_query = (flags & FTS5_TOKENIZE_QUERY) != 0; while (offset < str.length()) { // soak up non-token chars while (offset < str.length() && !is_token_char(str.char32At(offset))) offset = str.moveIndex32(offset, 1); @@ -135,10 +159,14 @@ public: if (offset > start_offset) { icu::UnicodeString token(str, start_offset, offset - start_offset); token.foldCase(U_FOLD_CASE_DEFAULT); - token_buf.clear(); token_buf.reserve(4 * (offset - start_offset)); - token.toUTF8String(token_buf); - int rc = send_token(start_offset, offset); - if (rc != SQLITE_OK) return rc; + if ((rc = send_token(token, start_offset, offset)) != SQLITE_OK) return rc; + if (!for_query && diacritics_remover) { + icu::UnicodeString tt(token); + diacritics_remover->transliterate(tt); + if (tt != token) { + if ((rc = send_token(tt, start_offset, offset, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc; + } + } } } return SQLITE_OK; @@ -161,15 +189,20 @@ fts5_api_from_db(sqlite3 *db, fts5_api **ppApi) { static int tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) { + int rc = SQLITE_OK; try { Tokenizer *p = new Tokenizer(azArg, nArg); *ppOut = reinterpret_cast(p); + if (p->constructor_error != SQLITE_OK) { + rc = p->constructor_error; + delete p; + } } catch (std::bad_alloc &ex) { return SQLITE_NOMEM; } catch (...) { return SQLITE_ERROR; } - return SQLITE_OK; + return rc; } static int diff --git a/src/calibre/db/tests/fts.py b/src/calibre/db/tests/fts.py index 322d4961cb..cd31cabb0c 100644 --- a/src/calibre/db/tests/fts.py +++ b/src/calibre/db/tests/fts.py @@ -45,6 +45,9 @@ class FTSTest(BaseTest): def test_basic_fts(self): # {{{ conn = TestConn() conn.insert_text('two words, and a period. With another.') - conn.insert_text('and another') - self.ae(conn.term_row_counts(), {'a': 1, 'and': 2, 'another': 2, 'period': 1, 'two': 1, 'with': 1, 'words': 1}) + conn.insert_text('and another re-init') + self.ae(conn.term_row_counts(), {'a': 1, 're': 1, 'init': 1, 'and': 2, 'another': 2, 'period': 1, 'two': 1, 'with': 1, 'words': 1}) + conn = TestConn() + conn.insert_text('coộl') + self.ae(conn.term_row_counts(), {'cool': 1, 'coộl': 1}) # }}}