mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Add tests for stemming
This commit is contained in:
parent
2bfc3d1e7f
commit
7fe5fff311
@ -141,6 +141,7 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
|
typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
|
||||||
|
typedef std::unique_ptr<Stemmer> StemmerPtr;
|
||||||
|
|
||||||
class Tokenizer {
|
class Tokenizer {
|
||||||
private:
|
private:
|
||||||
@ -151,7 +152,7 @@ private:
|
|||||||
token_callback_func current_callback;
|
token_callback_func current_callback;
|
||||||
void *current_callback_ctx;
|
void *current_callback_ctx;
|
||||||
std::map<const char*, BreakIterator, char_cmp> iterators;
|
std::map<const char*, BreakIterator, char_cmp> iterators;
|
||||||
std::map<const char*, Stemmer, char_cmp> stemmers;
|
std::map<const char*, StemmerPtr, char_cmp> stemmers;
|
||||||
|
|
||||||
bool is_token_char(UChar32 ch) const {
|
bool is_token_char(UChar32 ch) const {
|
||||||
switch(u_charType(ch)) {
|
switch(u_charType(ch)) {
|
||||||
@ -173,17 +174,18 @@ private:
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, Stemmer &stemmer, int flags = 0) {
|
int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, StemmerPtr &stemmer, int flags = 0) {
|
||||||
token_buf.clear(); token_buf.reserve(4 * token.length());
|
token_buf.clear(); token_buf.reserve(4 * token.length());
|
||||||
token.toUTF8String(token_buf);
|
token.toUTF8String(token_buf);
|
||||||
const char *root = token_buf.c_str(); int sz = (int)token_buf.size();
|
const char *root = token_buf.c_str(); int sz = (int)token_buf.size();
|
||||||
if (stem_words && stemmer) {
|
if (stem_words && stemmer->operator bool()) {
|
||||||
root = stemmer.stem(root, sz, &sz);
|
root = stemmer->stem(root, sz, &sz);
|
||||||
if (!root) {
|
if (!root) {
|
||||||
root = token_buf.c_str(); sz = (int)token_buf.size();
|
root = token_buf.c_str();
|
||||||
|
sz = (int)token_buf.size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return current_callback(current_callback_ctx, flags, token_buf.c_str(), (int)token_buf.size(), byte_offsets.at(start_offset), byte_offsets.at(end_offset));
|
return current_callback(current_callback_ctx, flags, root, (int)sz, byte_offsets.at(start_offset), byte_offsets.at(end_offset));
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* iterator_language_for_script(UScriptCode script) const {
|
const char* iterator_language_for_script(UScriptCode script) const {
|
||||||
@ -247,14 +249,18 @@ private:
|
|||||||
return ans->second;
|
return ans->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
Stemmer& ensure_stemmer(const char *lang = "") {
|
StemmerPtr& ensure_stemmer(const char *lang = "") {
|
||||||
if (!lang[0]) lang = current_ui_language.c_str();
|
if (!lang[0]) lang = current_ui_language.c_str();
|
||||||
auto ans = stemmers.find(lang);
|
auto ans = stemmers.find(lang);
|
||||||
if (ans == stemmers.end()) stemmers[lang] = stem_words ? Stemmer(lang) : Stemmer();
|
if (ans == stemmers.end()) {
|
||||||
|
if (stem_words) stemmers[lang] = StemmerPtr(new Stemmer(lang));
|
||||||
|
else stemmers[lang] = StemmerPtr(new Stemmer());
|
||||||
|
ans = stemmers.find(lang);
|
||||||
|
}
|
||||||
return ans->second;
|
return ans->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, Stemmer &stemmer) {
|
int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, StemmerPtr &stemmer) {
|
||||||
word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
|
word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
|
||||||
int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
|
int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
|
||||||
int rc = SQLITE_OK;
|
int rc = SQLITE_OK;
|
||||||
@ -493,6 +499,21 @@ tokenize(PyObject *self, PyObject *args) {
|
|||||||
return ans.detach();
|
return ans.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static PyObject*
|
||||||
|
stem(PyObject *self, PyObject *args) {
|
||||||
|
const char *text, *lang = "en"; int text_length;
|
||||||
|
if (!PyArg_ParseTuple(args, "s#|s", &text, &text_length, &lang)) return NULL;
|
||||||
|
Stemmer s(lang);
|
||||||
|
if (!s) {
|
||||||
|
PyErr_SetString(PyExc_ValueError, "No stemmer for the specified language");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
int sz;
|
||||||
|
const char* result = s.stem(text, text_length, &sz);
|
||||||
|
if (!result) return PyErr_NoMemory();
|
||||||
|
return Py_BuildValue("s#", result, sz);
|
||||||
|
}
|
||||||
|
|
||||||
static PyMethodDef methods[] = {
|
static PyMethodDef methods[] = {
|
||||||
{"get_locales_for_break_iteration", get_locales_for_break_iteration, METH_NOARGS,
|
{"get_locales_for_break_iteration", get_locales_for_break_iteration, METH_NOARGS,
|
||||||
"Get list of available locales for break iteration"
|
"Get list of available locales for break iteration"
|
||||||
@ -503,6 +524,12 @@ static PyMethodDef methods[] = {
|
|||||||
{"tokenize", tokenize, METH_VARARGS,
|
{"tokenize", tokenize, METH_VARARGS,
|
||||||
"Tokenize a string, useful for testing"
|
"Tokenize a string, useful for testing"
|
||||||
},
|
},
|
||||||
|
{"tokenize", tokenize, METH_VARARGS,
|
||||||
|
"Tokenize a string, useful for testing"
|
||||||
|
},
|
||||||
|
{"stem", stem, METH_VARARGS,
|
||||||
|
"Stem a word in the specified language, defaulting to English"
|
||||||
|
},
|
||||||
{NULL, NULL, 0, NULL}
|
{NULL, NULL, 0, NULL}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ def print(*args, **kwargs):
|
|||||||
|
|
||||||
class TestConn(Connection):
|
class TestConn(Connection):
|
||||||
|
|
||||||
def __init__(self, remove_diacritics=True, language='en'):
|
def __init__(self, remove_diacritics=True, language='en', stem_words=False):
|
||||||
from calibre_extensions.sqlite_extension import set_ui_language
|
from calibre_extensions.sqlite_extension import set_ui_language
|
||||||
set_ui_language(language)
|
set_ui_language(language)
|
||||||
super().__init__(':memory:')
|
super().__init__(':memory:')
|
||||||
@ -27,8 +27,9 @@ class TestConn(Connection):
|
|||||||
options = []
|
options = []
|
||||||
options.append('remove_diacritics'), options.append('2' if remove_diacritics else '0')
|
options.append('remove_diacritics'), options.append('2' if remove_diacritics else '0')
|
||||||
options = ' '.join(options)
|
options = ' '.join(options)
|
||||||
|
tok = 'porter ' if stem_words else ''
|
||||||
self.execute(f'''
|
self.execute(f'''
|
||||||
CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = 'unicode61 {options}');
|
CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = '{tok}unicode61 {options}');
|
||||||
CREATE VIRTUAL TABLE fts_row USING fts5vocab(fts_table, row);
|
CREATE VIRTUAL TABLE fts_row USING fts5vocab(fts_table, row);
|
||||||
''')
|
''')
|
||||||
|
|
||||||
@ -139,6 +140,26 @@ class FTSTest(BaseTest):
|
|||||||
self.ae(conn.search("叫"), [("你don't>叫<mess",)])
|
self.ae(conn.search("叫"), [("你don't>叫<mess",)])
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
|
def test_fts_stemming(self): # {{{
|
||||||
|
from calibre_extensions.sqlite_extension import stem
|
||||||
|
|
||||||
|
self.ae(stem('run'), 'run')
|
||||||
|
self.ae(stem('connection'), 'connect')
|
||||||
|
self.ae(stem('maintenaient'), 'maintenai')
|
||||||
|
self.ae(stem('maintenaient', 'fr'), 'mainten')
|
||||||
|
self.ae(stem('continué', 'fr'), 'continu')
|
||||||
|
self.ae(stem('maître', 'FRA'), 'maîtr')
|
||||||
|
|
||||||
|
conn = TestConn(stem_words=True)
|
||||||
|
conn.insert_text('a simplistic connection')
|
||||||
|
self.ae(conn.term_row_counts(), {'a': 1, 'connect': 1, 'simplist': 1})
|
||||||
|
self.ae(conn.search("connection"), [('a simplistic >connection<',),])
|
||||||
|
self.ae(conn.search("connect"), [('a simplistic >connection<',),])
|
||||||
|
self.ae(conn.search("simplistic connect"), [('a >simplistic< >connection<',),])
|
||||||
|
self.ae(conn.search("simplist"), [('a >simplistic< connection',),])
|
||||||
|
|
||||||
|
# }}}
|
||||||
|
|
||||||
def test_fts_query_syntax(self): # {{{
|
def test_fts_query_syntax(self): # {{{
|
||||||
conn = TestConn()
|
conn = TestConn()
|
||||||
conn.insert_text('one two three')
|
conn.insert_text('one two three')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user