Implement FTS API for notes

This commit is contained in:
Kovid Goyal 2023-08-18 17:56:28 +05:30
parent a0b9a799d9
commit 4d7f5b18d1
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 94 additions and 3 deletions

View File

@ -994,6 +994,12 @@ class DB:
def unretire_note(self, field, item_id, item_val): def unretire_note(self, field, item_id, item_val):
return self.notes.unretire(self.conn, field, item_id, item_val) return self.notes.unretire(self.conn, field, item_id, item_val)
def notes_search(
    self,
    fts_engine_query,
    use_stemming,
    highlight_start,
    highlight_end,
    snippet_size,
    restrict_to_fields,
    return_text,
    process_each_result,
):
    '''
    Delegate a full-text search of the notes to ``self.notes.search()``.

    Every argument is forwarded unchanged, preceded by ``self.conn``. This is
    a generator: results are yielded as the underlying search produces them,
    and because ``yield from`` is used, any value sent into this generator is
    passed straight through to the underlying search generator.
    '''
    forwarded = (
        fts_engine_query,
        use_stemming,
        highlight_start,
        highlight_end,
        snippet_size,
        restrict_to_fields,
        return_text,
        process_each_result,
    )
    yield from self.notes.search(self.conn, *forwarded)
def initialize_fts(self, dbref): def initialize_fts(self, dbref):
self.fts = None self.fts = None
if not self.prefs['fts_enabled']: if not self.prefs['fts_enabled']:

View File

@ -647,7 +647,7 @@ class Cache:
self._fts_start_measuring_rate() self._fts_start_measuring_rate()
return changed return changed
@write_api # we need to use write locking as SQLITE gives a locked table error is multiple FTS queries are made at the same time @write_api # we need to use write locking as SQLITE gives a locked table error if multiple FTS queries are made at the same time
def fts_search( def fts_search(
self, self,
fts_engine_query, fts_engine_query,
@ -697,6 +697,30 @@ class Cache:
@write_api @write_api
def unretire_note_for(self, field, item_id) -> int: def unretire_note_for(self, field, item_id) -> int:
return self.backend.unretire_note_for(field, item_id) return self.backend.unretire_note_for(field, item_id)
@write_api  # we need to use write locking as SQLITE gives a locked table error if multiple FTS queries are made at the same time
def notes_search(
    self,
    fts_engine_query,
    use_stemming=True,
    highlight_start=None,
    highlight_end=None,
    snippet_size=None,
    restrict_to_fields=(),
    return_text=True,
    result_type=tuple,
    process_each_result=None,
):
    '''
    Run a full-text search over the notes and return the collected results.

    The query and all search options are handed to
    ``self.backend.notes_search()``; whatever that call produces is passed to
    ``result_type`` (``tuple`` by default) and the outcome is returned.

    :param fts_engine_query: The query string given to the backend search.
    :param use_stemming: Forwarded to the backend as-is.
    :param highlight_start: Forwarded to the backend as-is.
    :param highlight_end: Forwarded to the backend as-is.
    :param snippet_size: Forwarded to the backend as-is.
    :param restrict_to_fields: Forwarded to the backend as-is.
    :param return_text: Forwarded to the backend as-is.
    :param result_type: Callable applied to the backend's results to build the return value.
    :param process_each_result: Forwarded to the backend as-is.
    '''
    search_options = {
        'use_stemming': use_stemming,
        'highlight_start': highlight_start,
        'highlight_end': highlight_end,
        'snippet_size': snippet_size,
        'return_text': return_text,
        'restrict_to_fields': restrict_to_fields,
        'process_each_result': process_each_result,
    }
    matches = self.backend.notes_search(fts_engine_query, **search_options)
    return result_type(matches)
# }}} # }}}
# Cache Layer API {{{ # Cache Layer API {{{

View File

@ -7,10 +7,12 @@ import shutil
import time import time
import xxhash import xxhash
from contextlib import suppress from contextlib import suppress
from itertools import repeat from itertools import count, repeat
from typing import Optional, Union from typing import Optional, Union
from calibre.constants import iswindows from calibre.constants import iswindows
from calibre.db import FTSQueryError
from calibre.db.annotations import unicode_normalize
from calibre.utils.copy_files import WINDOWS_SLEEP_FOR_RETRY_TIME from calibre.utils.copy_files import WINDOWS_SLEEP_FOR_RETRY_TIME
from calibre.utils.filenames import copyfile_using_links, make_long_path_useable from calibre.utils.filenames import copyfile_using_links, make_long_path_useable
from calibre.utils.icu import lower as icu_lower from calibre.utils.icu import lower as icu_lower
@ -56,6 +58,7 @@ class Notes:
def __init__(self, backend): def __init__(self, backend):
conn = backend.get_connection() conn = backend.get_connection()
self.temp_table_counter = count()
libdir = os.path.dirname(os.path.abspath(conn.db_filename('main'))) libdir = os.path.dirname(os.path.abspath(conn.db_filename('main')))
notes_dir = os.path.join(libdir, NOTES_DIR_NAME) notes_dir = os.path.join(libdir, NOTES_DIR_NAME)
self.resources_dir = os.path.join(notes_dir, 'resources') self.resources_dir = os.path.join(notes_dir, 'resources')
@ -293,3 +296,42 @@ class Notes:
path = make_long_path_useable(path) path = make_long_path_useable(path)
with suppress(FileNotFoundError), open(path, 'rb') as f: with suppress(FileNotFoundError), open(path, 'rb') as f:
return {'name': name, 'data': f.read(), 'hash': resource_hash} return {'name': name, 'data': f.read(), 'hash': resource_hash}
def search(self,
    conn, fts_engine_query, use_stemming, highlight_start, highlight_end, snippet_size, restrict_to_fields=(),
    return_text=True, process_each_result=None
):
    '''
    Search the notes full-text index, yielding one dict per matching note.

    :param conn: Database connection on which the query is executed.
    :param fts_engine_query: The FTS MATCH expression; it is passed through
        ``unicode_normalize()`` before use.
    :param use_stemming: If true, search the ``notes_fts_stemmed`` table,
        otherwise ``notes_fts``.
    :param highlight_start: Marker placed before each match in the returned
        text. Highlighting is applied only when both markers are not None.
    :param highlight_end: Marker placed after each match in the returned text.
    :param snippet_size: When highlighting and this is not None, return a
        snippet of this size (clamped to the range 1-64) instead of the whole
        highlighted text.
    :param restrict_to_fields: Iterable of field names to limit the search
        to. Empty or None means all fields.
    :param return_text: If false, no text is fetched and ``'text'`` in every
        result is the empty string.
    :param process_each_result: Optional callable applied to each result
        dict; its return value is what gets yielded.

    Each result is a dict with the keys ``id``, ``field``, ``item_id`` and
    ``text``, ordered by FTS rank. Sending ``True`` into the generator stops
    the search early.

    :raises FTSQueryError: if the database reports an ``apsw.SQLError``.
    '''
    def sql_quote(marker):
        # The markers are embedded in single quoted SQL string literals, so
        # any embedded single quote must be doubled to keep the literal (and
        # the statement) intact.
        return str(marker).replace("'", "''")

    fts_engine_query = unicode_normalize(fts_engine_query)
    # Accept any iterable of field names (list, set, ...) as well as None,
    # since the bindings below are built by tuple concatenation.
    restrict_to_fields = tuple(restrict_to_fields or ())
    fts_table = 'notes_fts' + ('_stemmed' if use_stemming else '')
    if return_text:
        text = 'notes.searchable_text'
        if highlight_start is not None and highlight_end is not None:
            hl_start, hl_end = sql_quote(highlight_start), sql_quote(highlight_end)
            if snippet_size is not None:
                # Keep the requested snippet size within 1-64 before embedding it in the SQL
                text = f'''snippet("{fts_table}", 0, '{hl_start}', '{hl_end}', '', {max(1, min(snippet_size, 64))})'''
            else:
                text = f'''highlight("{fts_table}", 0, '{hl_start}', '{hl_end}')'''
        text = ', ' + text
    else:
        text = ''
    query = 'SELECT {0}.id, {0}.colname, {0}.item {1} FROM {0} '.format('notes', text)
    query += f' JOIN {fts_table} ON notes_db.notes.id = {fts_table}.rowid'
    query += ' WHERE '
    if restrict_to_fields:
        query += ' notes_db.notes.colname IN ({}) AND '.format(','.join(repeat('?', len(restrict_to_fields))))
    query += f' "{fts_table}" MATCH ?'
    query += f' ORDER BY {fts_table}.rank '
    try:
        for record in conn.execute(query, restrict_to_fields + (fts_engine_query,)):
            result = {
                'id': record[0],
                'field': record[1],
                'item_id': record[2],
                'text': record[3] if return_text else '',
            }
            if process_each_result is not None:
                result = process_each_result(result)
            # The consumer can send True to abort iteration early
            ret = yield result
            if ret is True:
                break
    except apsw.SQLError as e:
        raise FTSQueryError(fts_engine_query, query, e) from e

View File

@ -56,7 +56,7 @@ def test_notes_api(self: 'NotesTest'):
self.ae(cache.get_notes_resource(h2)['data'], b'resource2') self.ae(cache.get_notes_resource(h2)['data'], b'resource2')
def test_cache_api(self): def test_cache_api(self: 'NotesTest'):
cache, notes = self.create_notes_db() cache, notes = self.create_notes_db()
authors = cache.field_for('authors', 1) authors = cache.field_for('authors', 1)
author_id = cache.get_item_id('authors', authors[0]) author_id = cache.get_item_id('authors', authors[0])
@ -89,6 +89,24 @@ def test_cache_api(self):
self.assertFalse(os.listdir(notes.retired_dir)) self.assertFalse(os.listdir(notes.retired_dir))
def test_fts(self: 'NotesTest'):
    '''Check that searching notes returns the expected (field, item_id) pairs.'''
    cache, _ = self.create_notes_db()

    # Two author notes and two tag notes; every note contains the word "common"
    authors = sorted(cache.all_field_ids('authors'))
    cache.set_notes_for('authors', authors[0], 'Wunderbar wunderkind common')
    cache.set_notes_for('authors', authors[1], 'Heavens to murgatroyd common')
    tags = sorted(cache.all_field_ids('tags'))
    cache.set_notes_for('tags', tags[0], 'Tag me baby, one more time common')
    cache.set_notes_for('tags', tags[1], 'Jeepers, Batman! common')

    def ids_for_search(query, restrict_to_fields=()):
        # Reduce the search results to a set of (field, item_id) pairs
        matches = cache.notes_search(query, restrict_to_fields=restrict_to_fields)
        return {(match['field'], match['item_id']) for match in matches}

    author_hits = {('authors', authors[0]), ('authors', authors[1])}
    tag_hits = {('tags', tags[0]), ('tags', tags[1])}

    # A word present in only one note
    self.ae(ids_for_search('wunderbar'), {('authors', authors[0])})
    # A word present in every note
    self.ae(ids_for_search('common'), author_hits | tag_hits)
    # The same word, restricted to a single field
    self.ae(ids_for_search('common', ('tags',)), tag_hits)
class NotesTest(BaseTest): class NotesTest(BaseTest):
ae = BaseTest.assertEqual ae = BaseTest.assertEqual
@ -99,5 +117,6 @@ class NotesTest(BaseTest):
return cache, cache.backend.notes return cache, cache.backend.notes
def test_notes(self): def test_notes(self):
test_fts(self)
test_cache_api(self) test_cache_api(self)
test_notes_api(self) test_notes_api(self)