Ensure initialization is thread safe

This commit is contained in:
Kovid Goyal 2022-05-02 08:11:06 +05:30
parent 958625f660
commit 368bba0ac3
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -3,15 +3,17 @@
# License: GPL v3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net> # License: GPL v3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import builtins, apsw import apsw
import builtins
import hashlib import hashlib
import os import os
import sys import sys
from contextlib import suppress from contextlib import suppress
from threading import Lock
from calibre.utils.date import EPOCH, utcnow
from calibre.db import FTSQueryError from calibre.db import FTSQueryError
from calibre.db.annotations import unicode_normalize from calibre.db.annotations import unicode_normalize
from calibre.utils.date import EPOCH, utcnow
from .pool import Pool from .pool import Pool
from .schema_upgrade import SchemaUpgrade from .schema_upgrade import SchemaUpgrade
@ -32,27 +34,32 @@ class FTS:
def __init__(self, dbref): def __init__(self, dbref):
self.dbref = dbref self.dbref = dbref
self.pool = Pool(dbref) self.pool = Pool(dbref)
self.init_lock = Lock()
def initialize(self, conn): def initialize(self, conn):
main_db_path = os.path.abspath(conn.db_filename('main')) needs_dirty = False
dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db') with self.init_lock:
conn.execute(f'ATTACH DATABASE "{dbpath}" AS fts_db') if conn.fts_dbpath is None:
SchemaUpgrade(conn) main_db_path = os.path.abspath(conn.db_filename('main'))
conn.fts_dbpath = dbpath dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db')
conn.execute('UPDATE fts_db.dirtied_formats SET in_progress=FALSE WHERE in_progress=TRUE') conn.execute(f'ATTACH DATABASE "{dbpath}" AS fts_db')
num_dirty = conn.get('''SELECT COUNT(*) from fts_db.dirtied_formats''')[0][0] SchemaUpgrade(conn)
if not num_dirty: conn.execute('UPDATE fts_db.dirtied_formats SET in_progress=FALSE WHERE in_progress=TRUE')
num_indexed = conn.get('''SELECT COUNT(*) from fts_db.books_text''')[0][0] num_dirty = conn.get('''SELECT COUNT(*) from fts_db.dirtied_formats''')[0][0]
if not num_indexed: if not num_dirty:
self.dirty_existing() num_indexed = conn.get('''SELECT COUNT(*) from fts_db.books_text''')[0][0]
if not num_indexed:
needs_dirty = True
conn.fts_dbpath = dbpath
if needs_dirty:
self.dirty_existing()
def get_connection(self): def get_connection(self):
db = self.dbref() db = self.dbref()
if db is None: if db is None:
raise RuntimeError('db has been garbage collected') raise RuntimeError('db has been garbage collected')
ans = db.backend.get_connection() ans = db.backend.get_connection()
if ans.fts_dbpath is None: self.initialize(ans)
self.initialize(ans)
return ans return ans
def dirty_existing(self): def dirty_existing(self):