Use a full dbref rather than just get_connection

This commit is contained in:
Kovid Goyal 2022-02-16 12:02:13 +05:30
parent 58bde2e304
commit 55c67d57e4
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 34 additions and 11 deletions

View File

@ -337,6 +337,7 @@ class Connection(apsw.Connection): # {{{
set_ui_language(get_lang())
super().__init__(path)
plugins.load_apsw_extension(self, 'sqlite_extension')
self.fts_dbpath = None
self.setbusytimeout(self.BUSY_TIMEOUT)
self.execute('pragma cache_size=-5000')
@ -492,7 +493,6 @@ class DB:
self.initialize_prefs(default_prefs, restore_all_prefs, progress_callback)
self.initialize_custom_columns()
self.initialize_tables()
self.initialize_fts()
self.set_user_template_functions(compile_user_template_functions(
self.prefs.get('user_template_functions', [])))
if load_user_formatter_functions:
@ -921,17 +921,18 @@ class DB:
# }}}
def initialize_fts(self):
def initialize_fts(self, dbref):
self.fts = None
if not self.prefs['fts_enabled']:
return
from .fts.connect import FTS
self.fts = FTS(self.get_connection)
self.fts = FTS(dbref)
def enable_fts(self, enabled=True):
def enable_fts(self, dbref=None):
enabled = dbref is not None
if enabled != self.prefs['fts_enabled']:
self.prefs['fts_enabled'] = enabled
self.initialize_fts()
self.initialize_fts(dbref)
if self.fts is not None:
self.fts.dirty_existing()
return self.fts

View File

@ -11,6 +11,7 @@ import random
import shutil
import sys
import traceback
import weakref
from collections import defaultdict
from collections.abc import MutableSet, Set
from functools import partial, wraps
@ -168,6 +169,7 @@ class Cache:
self._search_api = Search(self, 'saved_searches', self.field_metadata.get_search_terms())
self.initialize_dynamic()
self.initialize_fts()
@property
def new_api(self):
@ -421,6 +423,14 @@ class Cache:
self.update_last_modified(self.all_book_ids())
self.backend.prefs.set('update_all_last_mod_dates_on_start', False)
# FTS API {{{
def initialize_fts(self):
self.backend.initialize_fts(weakref.ref(self))
def enable_fts(self, enabled=True):
return self.backend.enable_fts(weakref.ref(self) if enabled else None)
# }}}
# Cache Layer API {{{
@write_api

View File

@ -11,6 +11,7 @@ from calibre.utils.date import EPOCH, utcnow
from .schema_upgrade import SchemaUpgrade
# TODO: check that closing of db connection works
# TODO: db dump+restore
# TODO: calibre export/import
# TODO: check library and vacuuming of fts db
@ -23,13 +24,24 @@ def print(*args, **kwargs):
class FTS:
def __init__(self, get_connection):
self.get_connection = get_connection
conn = self.get_connection()
def __init__(self, dbref):
self.dbref = dbref
def initialize(self, conn):
main_db_path = os.path.abspath(conn.db_filename('main'))
self.dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db')
conn.execute(f'ATTACH DATABASE "{self.dbpath}" AS fts_db')
dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db')
conn.execute(f'ATTACH DATABASE "{dbpath}" AS fts_db')
SchemaUpgrade(conn)
conn.fts_dbpath = dbpath
def get_connection(self):
db = self.dbref()
if db is None:
raise RuntimeError('db has been garbage collected')
ans = db.backend.conn
if ans.fts_dbpath is None:
self.initialize(ans)
return ans
def dirty_existing(self):
conn = self.get_connection()

View File

@ -30,7 +30,7 @@ class FTSAPITest(BaseTest):
def test_fts_triggers(self):
cache = self.init_cache()
fts = cache.backend.enable_fts()
fts = cache.enable_fts()
self.ae(fts.all_currently_dirty(), [(1, 'FMT1'), (1, 'FMT2'), (2, 'FMT1')])
fts.dirty_existing()
self.ae(fts.all_currently_dirty(), [(1, 'FMT1'), (1, 'FMT2'), (2, 'FMT1')])