Ensure initialization is thread safe

This commit is contained in:
Kovid Goyal 2022-05-02 08:11:06 +05:30
parent 958625f660
commit 368bba0ac3
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C

View File

@ -3,15 +3,17 @@
# License: GPL v3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net> # License: GPL v3 Copyright: 2022, Kovid Goyal <kovid at kovidgoyal.net>
import builtins, apsw import apsw
import builtins
import hashlib import hashlib
import os import os
import sys import sys
from contextlib import suppress from contextlib import suppress
from threading import Lock
from calibre.utils.date import EPOCH, utcnow
from calibre.db import FTSQueryError from calibre.db import FTSQueryError
from calibre.db.annotations import unicode_normalize from calibre.db.annotations import unicode_normalize
from calibre.utils.date import EPOCH, utcnow
from .pool import Pool from .pool import Pool
from .schema_upgrade import SchemaUpgrade from .schema_upgrade import SchemaUpgrade
@ -32,18 +34,24 @@ class FTS:
def __init__(self, dbref): def __init__(self, dbref):
self.dbref = dbref self.dbref = dbref
self.pool = Pool(dbref) self.pool = Pool(dbref)
self.init_lock = Lock()
def initialize(self, conn): def initialize(self, conn):
needs_dirty = False
with self.init_lock:
if conn.fts_dbpath is None:
main_db_path = os.path.abspath(conn.db_filename('main')) main_db_path = os.path.abspath(conn.db_filename('main'))
dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db') dbpath = os.path.join(os.path.dirname(main_db_path), 'full-text-search.db')
conn.execute(f'ATTACH DATABASE "{dbpath}" AS fts_db') conn.execute(f'ATTACH DATABASE "{dbpath}" AS fts_db')
SchemaUpgrade(conn) SchemaUpgrade(conn)
conn.fts_dbpath = dbpath
conn.execute('UPDATE fts_db.dirtied_formats SET in_progress=FALSE WHERE in_progress=TRUE') conn.execute('UPDATE fts_db.dirtied_formats SET in_progress=FALSE WHERE in_progress=TRUE')
num_dirty = conn.get('''SELECT COUNT(*) from fts_db.dirtied_formats''')[0][0] num_dirty = conn.get('''SELECT COUNT(*) from fts_db.dirtied_formats''')[0][0]
if not num_dirty: if not num_dirty:
num_indexed = conn.get('''SELECT COUNT(*) from fts_db.books_text''')[0][0] num_indexed = conn.get('''SELECT COUNT(*) from fts_db.books_text''')[0][0]
if not num_indexed: if not num_indexed:
needs_dirty = True
conn.fts_dbpath = dbpath
if needs_dirty:
self.dirty_existing() self.dirty_existing()
def get_connection(self): def get_connection(self):
@ -51,7 +59,6 @@ class FTS:
if db is None: if db is None:
raise RuntimeError('db has been garbage collected') raise RuntimeError('db has been garbage collected')
ans = db.backend.get_connection() ans = db.backend.get_connection()
if ans.fts_dbpath is None:
self.initialize(ans) self.initialize(ans)
return ans return ans