Implement diacritics removal in the new tokenizer

Kovid Goyal 2021-06-16 14:54:15 +05:30
parent ab313c836f
commit bbee5b0acb
2 changed files with 47 additions and 11 deletions
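
The tokenizer delegates the actual stripping to an ICU transliterator built from the compound rule "NFD; [:M:] Remove; NFC": decompose each character, remove everything in the Mark ([:M:]) general category, then recompose. A minimal standalone sketch (not part of this commit) of what that rule does:

    // Demonstrates the ICU rule the tokenizer below relies on.
    // Build with: g++ demo.cpp $(pkg-config --cflags --libs icu-i18n)
    #include <cstdio>
    #include <string>
    #include <unicode/errorcode.h>
    #include <unicode/translit.h>
    #include <unicode/unistr.h>

    int main() {
        icu::ErrorCode status;
        icu::Transliterator *remover = icu::Transliterator::createInstance(
            "NFD; [:M:] Remove; NFC", UTRANS_FORWARD, status);
        if (status.isFailure()) {
            fprintf(stderr, "createInstance failed: %s\n", status.errorName());
            return 1;
        }
        icu::UnicodeString s = icu::UnicodeString::fromUTF8("coộl");
        remover->transliterate(s);  // in place: "coộl" -> "cool"
        std::string utf8;
        s.toUTF8String(utf8);
        printf("%s\n", utf8.c_str());  // prints: cool
        delete remover;
        return 0;
    }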

@@ -14,6 +14,8 @@
 #include <sqlite3ext.h>
 #include <unicode/unistr.h>
 #include <unicode/uchar.h>
+#include <unicode/translit.h>
+#include <unicode/errorcode.h>
 SQLITE_EXTENSION_INIT1

 typedef int (*token_callback_func)(void *, int, const char *, int, int, int);
@@ -80,11 +82,11 @@ populate_icu_string(const char *text, int text_sz, icu::UnicodeString &str, std:
 class Tokenizer {
 private:
-    bool remove_diacritics;
+    icu::Transliterator *diacritics_remover;
     std::vector<int> byte_offsets;
-    std::string token_buf;
     token_callback_func current_callback;
     void *current_callback_ctx;
+    std::string token_buf;

     bool is_token_char(UChar32 ch) const {
         switch(u_charType(ch)) {
@@ -103,18 +105,38 @@ private:
         }
     }

-    int send_token(int32_t start_offset, int32_t end_offset, int flags = 0) {
+    int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, int flags = 0) {
+        token_buf.clear(); token_buf.reserve(4 * token.length());
+        token.toUTF8String(token_buf);
         return current_callback(current_callback_ctx, flags, token_buf.c_str(), token_buf.size(), byte_offsets[start_offset], byte_offsets[end_offset]);
     }

 public:
-    Tokenizer(const char **args, int nargs) : remove_diacritics(true), byte_offsets(), token_buf() {
+    int constructor_error;
+    Tokenizer(const char **args, int nargs) :
+        diacritics_remover(NULL),
+        byte_offsets(), token_buf(),
+        current_callback(NULL), current_callback_ctx(NULL), constructor_error(SQLITE_OK)
+    {
+        bool remove_diacritics = true;
         for (int i = 0; i < nargs; i++) {
             if (strcmp(args[i], "remove_diacritics") == 0) {
                 i++;
                 if (i < nargs && strcmp(args[i], "0") == 0) remove_diacritics = false;
             }
         }
+        if (remove_diacritics) {
+            icu::ErrorCode status;
+            diacritics_remover = icu::Transliterator::createInstance("NFD; [:M:] Remove; NFC", UTRANS_FORWARD, status);
+            if (status.isFailure()) {
+                fprintf(stderr, "Failed to create ICU transliterator to remove diacritics with error: %s\n", status.errorName());
+                constructor_error = SQLITE_INTERNAL;
+            }
+        }
+    }
+    ~Tokenizer() {
+        if (diacritics_remover) icu::Transliterator::unregister(diacritics_remover->getID());
+        diacritics_remover = NULL;
     }

     int tokenize(void *callback_ctx, int flags, const char *text, int text_sz, token_callback_func callback) {
@@ -125,6 +147,8 @@ public:
         populate_icu_string(text, text_sz, str, byte_offsets);
         str.foldCase(U_FOLD_CASE_DEFAULT);
         int32_t offset = str.getChar32Start(0);
+        int rc;
+        bool for_query = (flags & FTS5_TOKENIZE_QUERY) != 0;
         while (offset < str.length()) {
             // soak up non-token chars
             while (offset < str.length() && !is_token_char(str.char32At(offset))) offset = str.moveIndex32(offset, 1);
@@ -135,10 +159,14 @@ public:
             if (offset > start_offset) {
                 icu::UnicodeString token(str, start_offset, offset - start_offset);
                 token.foldCase(U_FOLD_CASE_DEFAULT);
-                token_buf.clear(); token_buf.reserve(4 * (offset - start_offset));
-                token.toUTF8String(token_buf);
-                int rc = send_token(start_offset, offset);
-                if (rc != SQLITE_OK) return rc;
+                if ((rc = send_token(token, start_offset, offset)) != SQLITE_OK) return rc;
+                if (!for_query && diacritics_remover) {
+                    icu::UnicodeString tt(token);
+                    diacritics_remover->transliterate(tt);
+                    if (tt != token) {
+                        if ((rc = send_token(tt, start_offset, offset, FTS5_TOKEN_COLOCATED)) != SQLITE_OK) return rc;
+                    }
+                }
             }
         }
         return SQLITE_OK;
@@ -161,15 +189,20 @@ fts5_api_from_db(sqlite3 *db, fts5_api **ppApi) {
 static int
 tok_create(void *sqlite3, const char **azArg, int nArg, Fts5Tokenizer **ppOut) {
+    int rc = SQLITE_OK;
     try {
         Tokenizer *p = new Tokenizer(azArg, nArg);
         *ppOut = reinterpret_cast<Fts5Tokenizer *>(p);
+        if (p->constructor_error != SQLITE_OK) {
+            rc = p->constructor_error;
+            delete p;
+        }
     } catch (std::bad_alloc &ex) {
         return SQLITE_NOMEM;
     } catch (...) {
         return SQLITE_ERROR;
     }
-    return SQLITE_OK;
+    return rc;
 }

 static int
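
For context, tok_create above fills the xCreate slot of an fts5_tokenizer, which is registered through the fts5_api obtained from fts5_api_from_db(). A hedged sketch of that wiring (the tokenizer name "unicode_words" and the tok_delete/tok_tokenize callbacks are assumptions; this diff shows only tok_create):

    // Sketch only: registering the tokenizer with FTS5. All names other
    // than tok_create and fts5_api_from_db are assumed, as they fall
    // outside this diff.
    static int
    register_tokenizer(sqlite3 *db) {
        fts5_api *api = NULL;
        int rc = fts5_api_from_db(db, &api);
        if (rc != SQLITE_OK) return rc;
        fts5_tokenizer tok = {tok_create, tok_delete, tok_tokenize};
        // The final argument is an optional destructor for the context pointer.
        return api->xCreateTokenizer(api, "unicode_words", (void *)api, &tok, NULL);
    }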

@@ -45,6 +45,9 @@ class FTSTest(BaseTest):
     def test_basic_fts(self): # {{{
         conn = TestConn()
         conn.insert_text('two words, and a period. With another.')
-        conn.insert_text('and another')
-        self.ae(conn.term_row_counts(), {'a': 1, 'and': 2, 'another': 2, 'period': 1, 'two': 1, 'with': 1, 'words': 1})
+        conn.insert_text('and another re-init')
+        self.ae(conn.term_row_counts(), {'a': 1, 're': 1, 'init': 1, 'and': 2, 'another': 2, 'period': 1, 'two': 1, 'with': 1, 'words': 1})
+        conn = TestConn()
+        conn.insert_text('coộl')
+        self.ae(conn.term_row_counts(), {'cool': 1, 'coộl': 1})
     # }}}
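
The 'coộl' assertion exercises the colocated-token path: at index time the tokenizer emits the stripped form as a second token flagged FTS5_TOKEN_COLOCATED, so 'coộl' and 'cool' occupy the same position in the index, while at query time (FTS5_TOKENIZE_QUERY) only the literal token is emitted, so a search for either form matches the row. An illustrative token_callback_func (not part of the test) that would show both emissions:

    // For the input "coộl" at index time this callback fires twice: once
    // with "coộl", then with "cool" carrying FTS5_TOKEN_COLOCATED and the
    // same byte offsets.
    static int
    print_token(void *ctx, int flags, const char *token, int nbytes,
                int start_byte, int end_byte) {
        printf("%.*s%s [bytes %d..%d]\n", nbytes, token,
               (flags & FTS5_TOKEN_COLOCATED) ? " (colocated)" : "",
               start_byte, end_byte);
        return SQLITE_OK;
    }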