From 4d7f5b18d1c6fce3ec454b578caee2660174d5c9 Mon Sep 17 00:00:00 2001
From: Kovid Goyal
Date: Fri, 18 Aug 2023 17:56:28 +0530
Subject: [PATCH] Implement FTS API for notes

---
 src/calibre/db/backend.py       |  6 +++++
 src/calibre/db/cache.py         | 26 ++++++++++++++++++-
 src/calibre/db/notes/connect.py | 58 ++++++++++++++++++++++++++++++++-
 src/calibre/db/tests/notes.py   | 21 +++++++++++++++-
 4 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/src/calibre/db/backend.py b/src/calibre/db/backend.py
index 5bba130b48..d1b5ba3704 100644
--- a/src/calibre/db/backend.py
+++ b/src/calibre/db/backend.py
@@ -994,6 +994,12 @@ class DB:
     def unretire_note(self, field, item_id, item_val):
         return self.notes.unretire(self.conn, field, item_id, item_val)
 
+    def notes_search(self,
+        fts_engine_query, use_stemming, highlight_start, highlight_end, snippet_size, restrict_to_fields, return_text, process_each_result
+    ):
+        yield from self.notes.search(
+            self.conn, fts_engine_query, use_stemming, highlight_start, highlight_end, snippet_size, restrict_to_fields, return_text, process_each_result)
+
     def initialize_fts(self, dbref):
         self.fts = None
         if not self.prefs['fts_enabled']:
diff --git a/src/calibre/db/cache.py b/src/calibre/db/cache.py
index 8d46cf6f1c..38053b759c 100644
--- a/src/calibre/db/cache.py
+++ b/src/calibre/db/cache.py
@@ -647,7 +647,7 @@ class Cache:
             self._fts_start_measuring_rate()
         return changed
 
-    @write_api  # we need to use write locking as SQLITE gives a locked table error is multiple FTS queries are made at the same time
+    @write_api  # we need to use write locking as SQLITE gives a locked table error if multiple FTS queries are made at the same time
     def fts_search(
         self,
         fts_engine_query,
@@ -697,6 +697,30 @@ class Cache:
     @write_api
     def unretire_note_for(self, field, item_id) -> int:
         return self.backend.unretire_note_for(field, item_id)
+
+    @write_api  # we need to use write locking as SQLITE gives a locked table error if multiple FTS queries are made at the same time
+    def notes_search(
+        self,
+        fts_engine_query,
+        use_stemming=True,
+        highlight_start=None,
+        highlight_end=None,
+        snippet_size=None,
+        restrict_to_fields=(),
+        return_text=True,
+        result_type=tuple,
+        process_each_result=None,
+    ):
+        return result_type(self.backend.notes_search(
+            fts_engine_query,
+            use_stemming=use_stemming,
+            highlight_start=highlight_start,
+            highlight_end=highlight_end,
+            snippet_size=snippet_size,
+            return_text=return_text,
+            restrict_to_fields=restrict_to_fields,
+            process_each_result=process_each_result,
+        ))
     # }}}
 
     # Cache Layer API {{{
diff --git a/src/calibre/db/notes/connect.py b/src/calibre/db/notes/connect.py
index fddb2a7c6b..ba51550f3a 100644
--- a/src/calibre/db/notes/connect.py
+++ b/src/calibre/db/notes/connect.py
@@ -7,10 +7,12 @@ import shutil
 import time
 import xxhash
 from contextlib import suppress
-from itertools import repeat
+from itertools import count, repeat
 from typing import Optional, Union
 
 from calibre.constants import iswindows
+from calibre.db import FTSQueryError
+from calibre.db.annotations import unicode_normalize
 from calibre.utils.copy_files import WINDOWS_SLEEP_FOR_RETRY_TIME
 from calibre.utils.filenames import copyfile_using_links, make_long_path_useable
 from calibre.utils.icu import lower as icu_lower
@@ -56,6 +58,7 @@ class Notes:
 
     def __init__(self, backend):
         conn = backend.get_connection()
+        self.temp_table_counter = count()
         libdir = os.path.dirname(os.path.abspath(conn.db_filename('main')))
         notes_dir = os.path.join(libdir, NOTES_DIR_NAME)
         self.resources_dir = os.path.join(notes_dir, 'resources')
@@ -293,3 +296,56 @@ class Notes:
             path = make_long_path_useable(path)
             with suppress(FileNotFoundError), open(path, 'rb') as f:
                 return {'name': name, 'data': f.read(), 'hash': resource_hash}
+
+    def search(self,
+        conn, fts_engine_query, use_stemming, highlight_start, highlight_end, snippet_size, restrict_to_fields=(),
+        return_text=True, process_each_result=None
+    ):
+        '''
+        Search the notes full text index, yielding one result dict per matching
+        note, ordered by FTS rank. Each dict has the keys ``id``, ``field``,
+        ``item_id`` and ``text`` and is passed through process_each_result, if
+        given, before being yielded. Sending True into the generator stops the
+        search. SQL errors are re-raised as FTSQueryError.
+        '''
+        fts_engine_query = unicode_normalize(fts_engine_query)
+        fts_table = 'notes_fts' + ('_stemmed' if use_stemming else '')
+        hl_data = ()
+        if return_text:
+            text = 'notes.searchable_text'
+            if highlight_start is not None and highlight_end is not None:
+                # Bind the markers as parameters rather than interpolating them, so
+                # that markers containing quotes cannot break or alter the query
+                if snippet_size is not None:
+                    text = f'''snippet("{fts_table}", 0, ?, ?, '…', {max(1, min(snippet_size, 64))})'''
+                else:
+                    text = f'''highlight("{fts_table}", 0, ?, ?)'''
+                hl_data = (highlight_start, highlight_end)
+            text = ', ' + text
+        else:
+            text = ''
+        query = 'SELECT {0}.id, {0}.colname, {0}.item {1} FROM {0} '.format('notes', text)
+        query += f' JOIN {fts_table} ON notes_db.notes.id = {fts_table}.rowid'
+        query += ' WHERE '
+        if restrict_to_fields:
+            query += ' notes_db.notes.colname IN ({}) AND '.format(','.join(repeat('?', len(restrict_to_fields))))
+        query += f' "{fts_table}" MATCH ?'
+        query += f' ORDER BY {fts_table}.rank '
+        # Placeholders are bound in order of appearance: highlight markers in the
+        # SELECT list first, then the field names, then the MATCH expression
+        params = hl_data + tuple(restrict_to_fields) + (fts_engine_query,)
+        try:
+            for record in conn.execute(query, params):
+                result = {
+                    'id': record[0],
+                    'field': record[1],
+                    'item_id': record[2],
+                    'text': record[3] if return_text else '',
+                }
+                if process_each_result is not None:
+                    result = process_each_result(result)
+                ret = yield result
+                if ret is True:
+                    break
+        except apsw.SQLError as e:
+            raise FTSQueryError(fts_engine_query, query, e) from e
diff --git a/src/calibre/db/tests/notes.py b/src/calibre/db/tests/notes.py
index 68be4dde47..bca74657f4 100644
--- a/src/calibre/db/tests/notes.py
+++ b/src/calibre/db/tests/notes.py
@@ -56,7 +56,7 @@ def test_notes_api(self: 'NotesTest'):
     self.ae(cache.get_notes_resource(h2)['data'], b'resource2')
 
 
-def test_cache_api(self):
+def test_cache_api(self: 'NotesTest'):
     cache, notes = self.create_notes_db()
     authors = cache.field_for('authors', 1)
     author_id = cache.get_item_id('authors', authors[0])
@@ -89,6 +89,24 @@ def test_cache_api(self):
     self.assertFalse(os.listdir(notes.retired_dir))
 
 
+def test_fts(self: 'NotesTest'):
+    cache, _ = self.create_notes_db()
+    authors = sorted(cache.all_field_ids('authors'))
+    cache.set_notes_for('authors', authors[0], 'Wunderbar wunderkind common')
+    cache.set_notes_for('authors', authors[1], 'Heavens to murgatroyd common')
+    tags = sorted(cache.all_field_ids('tags'))
+    cache.set_notes_for('tags', tags[0], 'Tag me baby, one more time common')
+    cache.set_notes_for('tags', tags[1], 'Jeepers, Batman! common')
+
+    def ids_for_search(x, restrict_to_fields=()):
+        return {
+            (x['field'], x['item_id']) for x in cache.notes_search(x, restrict_to_fields=restrict_to_fields)
+        }
+
+    self.ae(ids_for_search('wunderbar'), {('authors', authors[0])})
+    self.ae(ids_for_search('common'), {('authors', authors[0]), ('authors', authors[1]), ('tags', tags[0]), ('tags', tags[1])})
+    self.ae(ids_for_search('common', ('tags',)), {('tags', tags[0]), ('tags', tags[1])})
+
 class NotesTest(BaseTest):
 
     ae = BaseTest.assertEqual
@@ -99,5 +117,6 @@ class NotesTest(BaseTest):
         return cache, cache.backend.notes
 
     def test_notes(self):
+        test_fts(self)
         test_cache_api(self)
         test_notes_api(self)