mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Add tests for stemming
This commit is contained in:
parent
2bfc3d1e7f
commit
7fe5fff311
@ -141,6 +141,7 @@ public:
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<icu::BreakIterator> BreakIterator;
|
||||
typedef std::unique_ptr<Stemmer> StemmerPtr;
|
||||
|
||||
class Tokenizer {
|
||||
private:
|
||||
@ -151,7 +152,7 @@ private:
|
||||
token_callback_func current_callback;
|
||||
void *current_callback_ctx;
|
||||
std::map<const char*, BreakIterator, char_cmp> iterators;
|
||||
std::map<const char*, Stemmer, char_cmp> stemmers;
|
||||
std::map<const char*, StemmerPtr, char_cmp> stemmers;
|
||||
|
||||
bool is_token_char(UChar32 ch) const {
|
||||
switch(u_charType(ch)) {
|
||||
@ -173,17 +174,18 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
// Convert `token` to UTF-8 in token_buf and pass it to current_callback
// together with the byte offsets looked up in byte_offsets for start_offset
// and end_offset. When stem_words is set and a usable stemmer is supplied,
// the stemmed form is passed instead; if stem() returns NULL the unstemmed
// text is used as a fallback. Returns whatever current_callback returns.
int send_token(const icu::UnicodeString &token, int32_t start_offset, int32_t end_offset, StemmerPtr &stemmer, int flags = 0) {
    token_buf.clear();
    token_buf.reserve(4 * token.length());
    token.toUTF8String(token_buf);

    const char *root = token_buf.c_str();
    int sz = (int)token_buf.size();

    // Test the owning pointer before the Stemmer it holds, so that an empty
    // StemmerPtr is never dereferenced.
    if (stem_words && stemmer && *stemmer) {
        const char *stemmed = stemmer->stem(root, sz, &sz);
        if (stemmed) {
            root = stemmed;
        } else {
            // sz was handed to stem() as an out-parameter; restore the
            // length of the unstemmed text that root still points at.
            sz = (int)token_buf.size();
        }
    }
    return current_callback(current_callback_ctx, flags, root, sz, byte_offsets.at(start_offset), byte_offsets.at(end_offset));
}
|
||||
|
||||
const char* iterator_language_for_script(UScriptCode script) const {
|
||||
@ -247,14 +249,18 @@ private:
|
||||
return ans->second;
|
||||
}
|
||||
|
||||
Stemmer& ensure_stemmer(const char *lang = "") {
|
||||
StemmerPtr& ensure_stemmer(const char *lang = "") {
|
||||
if (!lang[0]) lang = current_ui_language.c_str();
|
||||
auto ans = stemmers.find(lang);
|
||||
if (ans == stemmers.end()) stemmers[lang] = stem_words ? Stemmer(lang) : Stemmer();
|
||||
if (ans == stemmers.end()) {
|
||||
if (stem_words) stemmers[lang] = StemmerPtr(new Stemmer(lang));
|
||||
else stemmers[lang] = StemmerPtr(new Stemmer());
|
||||
ans = stemmers.find(lang);
|
||||
}
|
||||
return ans->second;
|
||||
}
|
||||
|
||||
int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, Stemmer &stemmer) {
|
||||
int tokenize_script_block(const icu::UnicodeString &str, int32_t block_start, int32_t block_limit, bool for_query, token_callback_func callback, void *callback_ctx, BreakIterator &word_iterator, StemmerPtr &stemmer) {
|
||||
word_iterator->setText(str.tempSubStringBetween(block_start, block_limit));
|
||||
int32_t token_start_pos = word_iterator->first() + block_start, token_end_pos;
|
||||
int rc = SQLITE_OK;
|
||||
@ -493,6 +499,21 @@ tokenize(PyObject *self, PyObject *args) {
|
||||
return ans.detach();
|
||||
}
|
||||
|
||||
// Python-callable helper: stem(text, lang="en") -> str.
// Stems `text` with a Stemmer constructed for `lang` and returns the result
// decoded from UTF-8.
// Raises ValueError when the Stemmer for `lang` evaluates false,
// OverflowError when the input is longer than an int can describe, and
// MemoryError when the stemmer returns NULL.
static PyObject*
stem(PyObject *self, PyObject *args) {
    // "s*" fills a Py_buffer whose length is always a Py_ssize_t, so parsing
    // does not depend on whether PY_SSIZE_T_CLEAN was defined before
    // Python.h (with "s#" the width of the length slot depends on that macro).
    Py_buffer text = {};
    const char *lang = "en";
    if (!PyArg_ParseTuple(args, "s*|s", &text, &lang)) return NULL;

    PyObject *ans = NULL;
    if (text.len > INT_MAX) {
        // Stemmer::stem() takes the length as an int; refuse instead of truncating.
        PyErr_SetString(PyExc_OverflowError, "Text to stem is too long");
    } else {
        Stemmer s(lang);
        if (!s) {
            PyErr_SetString(PyExc_ValueError, "No stemmer for the specified language");
        } else {
            int sz = 0;
            const char *result = s.stem(static_cast<const char*>(text.buf), static_cast<int>(text.len), &sz);
            // Build the Python string while both `s` and the input buffer are
            // still alive, since `result` may point into either of them.
            if (result) ans = PyUnicode_FromStringAndSize(result, sz);
            else PyErr_NoMemory();
        }
    }
    PyBuffer_Release(&text);
    return ans;
}
|
||||
|
||||
static PyMethodDef methods[] = {
|
||||
{"get_locales_for_break_iteration", get_locales_for_break_iteration, METH_NOARGS,
|
||||
"Get list of available locales for break iteration"
|
||||
@ -503,6 +524,12 @@ static PyMethodDef methods[] = {
|
||||
{"tokenize", tokenize, METH_VARARGS,
|
||||
"Tokenize a string, useful for testing"
|
||||
},
|
||||
{"tokenize", tokenize, METH_VARARGS,
|
||||
"Tokenize a string, useful for testing"
|
||||
},
|
||||
{"stem", stem, METH_VARARGS,
|
||||
"Stem a word in the specified language, defaulting to English"
|
||||
},
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
|
@ -19,7 +19,7 @@ def print(*args, **kwargs):
|
||||
|
||||
class TestConn(Connection):
|
||||
|
||||
def __init__(self, remove_diacritics=True, language='en'):
|
||||
def __init__(self, remove_diacritics=True, language='en', stem_words=False):
|
||||
from calibre_extensions.sqlite_extension import set_ui_language
|
||||
set_ui_language(language)
|
||||
super().__init__(':memory:')
|
||||
@ -27,8 +27,9 @@ class TestConn(Connection):
|
||||
options = []
|
||||
options.append('remove_diacritics'), options.append('2' if remove_diacritics else '0')
|
||||
options = ' '.join(options)
|
||||
tok = 'porter ' if stem_words else ''
|
||||
self.execute(f'''
|
||||
CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = 'unicode61 {options}');
|
||||
CREATE VIRTUAL TABLE fts_table USING fts5(t, tokenize = '{tok}unicode61 {options}');
|
||||
CREATE VIRTUAL TABLE fts_row USING fts5vocab(fts_table, row);
|
||||
''')
|
||||
|
||||
@ -139,6 +140,26 @@ class FTSTest(BaseTest):
|
||||
self.ae(conn.search("叫"), [("你don't>叫<mess",)])
|
||||
# }}}
|
||||
|
||||
def test_fts_stemming(self):  # {{{
    from calibre_extensions.sqlite_extension import stem

    # Direct checks of stem(): (positional arguments, expected result).
    # The first three cases pass no language; the rest name one explicitly.
    stem_cases = (
        (('run',), 'run'),
        (('connection',), 'connect'),
        (('maintenaient',), 'maintenai'),
        (('maintenaient', 'fr'), 'mainten'),
        (('continué', 'fr'), 'continu'),
        (('maître', 'FRA'), 'maîtr'),
    )
    for stem_args, expected_root in stem_cases:
        self.ae(stem(*stem_args), expected_root)

    # A table created with stem_words=True should index the stemmed terms ...
    conn = TestConn(stem_words=True)
    conn.insert_text('a simplistic connection')
    self.ae(conn.term_row_counts(), {'a': 1, 'connect': 1, 'simplist': 1})

    # ... and match both stemmed and unstemmed query words:
    # (query, expected highlighted text of the single matching row).
    search_cases = (
        ("connection", 'a simplistic >connection<'),
        ("connect", 'a simplistic >connection<'),
        ("simplistic connect", 'a >simplistic< >connection<'),
        ("simplist", 'a >simplistic< connection'),
    )
    for query, highlighted in search_cases:
        self.ae(conn.search(query), [(highlighted,)])

# }}}
|
||||
|
||||
def test_fts_query_syntax(self): # {{{
|
||||
conn = TestConn()
|
||||
conn.insert_text('one two three')
|
||||
|
Loading…
x
Reference in New Issue
Block a user