Use a full dbref rather than just get_connection

This commit is contained in:
Kovid Goyal 2022-02-16 12:02:13 +05:30
parent 58bde2e304
commit 55c67d57e4
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 34 additions and 11 deletions

View File

@ -337,6 +337,7 @@ class Connection(apsw.Connection): # {{{
set_ui_language(get_lang()) set_ui_language(get_lang())
super().__init__(path) super().__init__(path)
plugins.load_apsw_extension(self, 'sqlite_extension') plugins.load_apsw_extension(self, 'sqlite_extension')
self.fts_dbpath = None
self.setbusytimeout(self.BUSY_TIMEOUT) self.setbusytimeout(self.BUSY_TIMEOUT)
self.execute('pragma cache_size=-5000') self.execute('pragma cache_size=-5000')
@ -492,7 +493,6 @@ class DB:
self.initialize_prefs(default_prefs, restore_all_prefs, progress_callback) self.initialize_prefs(default_prefs, restore_all_prefs, progress_callback)
self.initialize_custom_columns() self.initialize_custom_columns()
self.initialize_tables() self.initialize_tables()
self.initialize_fts()
self.set_user_template_functions(compile_user_template_functions( self.set_user_template_functions(compile_user_template_functions(
self.prefs.get('user_template_functions', []))) self.prefs.get('user_template_functions', [])))
if load_user_formatter_functions: if load_user_formatter_functions:
@ -921,17 +921,18 @@ class DB:
# }}} # }}}
def initialize_fts(self): def initialize_fts(self, dbref):
self.fts = None self.fts = None
if not self.prefs['fts_enabled']: if not self.prefs['fts_enabled']:
return return
from .fts.connect import FTS from .fts.connect import FTS
self.fts = FTS(self.get_connection) self.fts = FTS(dbref)
def enable_fts(self, enabled=True): def enable_fts(self, dbref=None):
enabled = dbref is not None
if enabled != self.prefs['fts_enabled']: if enabled != self.prefs['fts_enabled']:
self.prefs['fts_enabled'] = enabled self.prefs['fts_enabled'] = enabled
self.initialize_fts() self.initialize_fts(dbref)
if self.fts is not None: if self.fts is not None:
self.fts.dirty_existing() self.fts.dirty_existing()
return self.fts return self.fts

View File

@ -11,6 +11,7 @@ import random
import shutil import shutil
import sys import sys
import traceback import traceback
import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import MutableSet, Set from collections.abc import MutableSet, Set
from functools import partial, wraps from functools import partial, wraps
@ -168,6 +169,7 @@ class Cache:
self._search_api = Search(self, 'saved_searches', self.field_metadata.get_search_terms()) self._search_api = Search(self, 'saved_searches', self.field_metadata.get_search_terms())
self.initialize_dynamic() self.initialize_dynamic()
self.initialize_fts()
@property @property
def new_api(self): def new_api(self):
@ -421,6 +423,14 @@ class Cache:
self.update_last_modified(self.all_book_ids()) self.update_last_modified(self.all_book_ids())
self.backend.prefs.set('update_all_last_mod_dates_on_start', False) self.backend.prefs.set('update_all_last_mod_dates_on_start', False)
# FTS API {{{
def initialize_fts(self):
self.backend.initialize_fts(weakref.ref(self))
def enable_fts(self, enabled=True):
return self.backend.enable_fts(weakref.ref(self) if enabled else None)
# }}}
# Cache Layer API {{{ # Cache Layer API {{{
@write_api @write_api

View File

@ -11,6 +11,7 @@ from calibre.utils.date import EPOCH, utcnow
from .schema_upgrade import SchemaUpgrade from .schema_upgrade import SchemaUpgrade
# TODO: check that closing of db connection works
# TODO: db dump+restore # TODO: db dump+restore
# TODO: calibre export/import # TODO: calibre export/import
# TODO: check library and vacuuming of fts db # TODO: check library and vacuuming of fts db
@ -23,13 +24,24 @@ def print(*args, **kwargs):
class FTS: class FTS:
def __init__(self, get_connection): def __init__(self, dbref):
self.get_connection = get_connection self.dbref = dbref
conn = self.get_connection()
def initialize(self, conn):
main_db_path = os.path.abspath(conn.db_filename('main')) main_db_path = os.path.abspath(conn.db_filename('main'))
self.dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db') dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db')
conn.execute(f'ATTACH DATABASE "{self.dbpath}" AS fts_db') conn.execute(f'ATTACH DATABASE "{dbpath}" AS fts_db')
SchemaUpgrade(conn) SchemaUpgrade(conn)
conn.fts_dbpath = dbpath
def get_connection(self):
db = self.dbref()
if db is None:
raise RuntimeError('db has been garbage collected')
ans = db.backend.conn
if ans.fts_dbpath is None:
self.initialize(ans)
return ans
def dirty_existing(self): def dirty_existing(self):
conn = self.get_connection() conn = self.get_connection()

View File

@ -30,7 +30,7 @@ class FTSAPITest(BaseTest):
def test_fts_triggers(self): def test_fts_triggers(self):
cache = self.init_cache() cache = self.init_cache()
fts = cache.backend.enable_fts() fts = cache.enable_fts()
self.ae(fts.all_currently_dirty(), [(1, 'FMT1'), (1, 'FMT2'), (2, 'FMT1')]) self.ae(fts.all_currently_dirty(), [(1, 'FMT1'), (1, 'FMT2'), (2, 'FMT1')])
fts.dirty_existing() fts.dirty_existing()
self.ae(fts.all_currently_dirty(), [(1, 'FMT1'), (1, 'FMT2'), (2, 'FMT1')]) self.ae(fts.all_currently_dirty(), [(1, 'FMT1'), (1, 'FMT2'), (2, 'FMT1')])