More work on fts indexing

This commit is contained in:
Kovid Goyal 2022-02-19 13:09:04 +05:30
parent 2c4891b26d
commit d009e10942
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 67 additions and 18 deletions

View File

@ -930,17 +930,29 @@ class DB:
def enable_fts(self, dbref=None): def enable_fts(self, dbref=None):
enabled = dbref is not None enabled = dbref is not None
if enabled != self.prefs['fts_enabled']: self.prefs['fts_enabled'] = enabled
self.prefs['fts_enabled'] = enabled self.initialize_fts(dbref)
self.initialize_fts(dbref) if self.fts is not None:
if self.fts is not None: self.fts.dirty_existing()
self.fts.dirty_existing()
return self.fts return self.fts
@property @property
def fts_enabled(self): def fts_enabled(self):
return getattr(self, 'fts', None) is not None return getattr(self, 'fts', None) is not None
@property
def fts_has_idle_workers(self):
return self.fts_enabled and self.fts.pool.num_of_idle_workers > 0
@property
def fts_num_of_workers(self):
return self.fts.pool.num_of_workers if self.fts_enabled else 0
@fts_num_of_workers.setter
def fts_num_of_workers(self, num):
if self.fts_enabled:
self.fts.num_of_workers = num
def get_next_fts_job(self): def get_next_fts_job(self):
return self.fts.get_next_fts_job() return self.fts.get_next_fts_job()

View File

@ -448,7 +448,13 @@ class Cache:
if not path or not is_fmt_ok(fmt): if not path or not is_fmt_ok(fmt):
self.backend.remove_dirty_fts(book_id, fmt) self.backend.remove_dirty_fts(book_id, fmt)
continue continue
with PersistentTemporaryFile(suffix=f'.{fmt.lower()}') as pt, open(path, 'rb') as src: try:
src = open(path, 'rb')
except OSError:
self.backend.remove_dirty_fts(book_id, fmt)
traceback.print_exc()
continue
with PersistentTemporaryFile(suffix=f'.{fmt.lower()}') as pt, src:
sz = 0 sz = 0
h = hashlib.sha1() h = hashlib.sha1()
while True: while True:
@ -458,13 +464,23 @@ class Cache:
sz += len(chunk) sz += len(chunk)
h.update(chunk) h.update(chunk)
pt.write(chunk) pt.write(chunk)
if self.backend.queue_fts_job(book_id, fmt, path, sz, h.hexdigest()): if self.backend.queue_fts_job(book_id, fmt, pt.name, sz, h.hexdigest()):
break if not self.backend.fts_has_idle_workers:
break
@write_api @write_api
def commit_fts_result(self, book_id, fmt, fmt_size, fmt_hash, text): def commit_fts_result(self, book_id, fmt, fmt_size, fmt_hash, text):
return self.backend.commit_fts_result(book_id, fmt, fmt_size, fmt_hash, text) return self.backend.commit_fts_result(book_id, fmt, fmt_size, fmt_hash, text)
@api
def set_fts_num_of_workers(self, num=None):
existing = self.backend.fts_num_of_workers
if num is not None and num != existing:
self.backend.fts_num_of_workers = num
if num > existing:
self.queue_next_fts_job()
return existing
# }}} # }}}
# Cache Layer API {{{ # Cache Layer API {{{
@ -1622,7 +1638,6 @@ class Cache:
try: try:
stream = stream_or_path if hasattr(stream_or_path, 'read') else lopen(stream_or_path, 'rb') stream = stream_or_path if hasattr(stream_or_path, 'read') else lopen(stream_or_path, 'rb')
size, fname = self._do_add_format(book_id, fmt, stream, name) size, fname = self._do_add_format(book_id, fmt, stream, name)
self._queue_next_fts_job()
finally: finally:
if needs_close: if needs_close:
stream.close() stream.close()
@ -1639,6 +1654,7 @@ class Cache:
run_plugins_on_postimport(dbapi or self, book_id, fmt) run_plugins_on_postimport(dbapi or self, book_id, fmt)
stream_or_path.close() stream_or_path.close()
self.queue_next_fts_job()
return True return True
@write_api @write_api

View File

@ -91,7 +91,7 @@ class FTS:
text_hash = '' text_hash = ''
if text: if text:
text_hash = hashlib.sha1(text.encode('utf-8')).hexdigest() text_hash = hashlib.sha1(text.encode('utf-8')).hexdigest()
for x in conn.get('SELECT id FROM fts_db.books_text WHERE book=? AND fmt=? AND text_hash=?', (book_id, fmt, text_hash)): for x in conn.get('SELECT id FROM fts_db.books_text WHERE book=? AND format=? AND text_hash=?', (book_id, fmt, text_hash)):
text = '' text = ''
break break
self.add_text(book_id, fmt, text, text_hash, fmt_size, fmt_hash) self.add_text(book_id, fmt, text, text_hash, fmt_size, fmt_hash)
@ -99,12 +99,12 @@ class FTS:
def queue_job(self, book_id, fmt, path, fmt_size, fmt_hash): def queue_job(self, book_id, fmt, path, fmt_size, fmt_hash):
conn = self.get_connection() conn = self.get_connection()
fmt = fmt.upper() fmt = fmt.upper()
for x in conn.get('SELECT id FROM fts_db.books_text WHERE book=? AND fmt=? AND format_size=? AND format_hash=?', ( for x in conn.get('SELECT id FROM fts_db.books_text WHERE book=? AND format=? AND format_size=? AND format_hash=?', (
book_id, fmt, fmt_size, fmt_hash)): book_id, fmt, fmt_size, fmt_hash)):
break break
else: else:
self.pool.add_job(book_id, fmt, path, fmt_size, fmt_hash) self.pool.add_job(book_id, fmt, path, fmt_size, fmt_hash)
conn.execute('UPDATE fts_db.dirtied_formats SET in_progress=TRUE WHERE book=? AND format=? LIMIT 1', (book_id, fmt)) conn.execute('UPDATE fts_db.dirtied_formats SET in_progress=TRUE WHERE book=? AND format=?', (book_id, fmt))
return True return True
self.remove_dirty(book_id, fmt) self.remove_dirty(book_id, fmt)
with suppress(OSError): with suppress(OSError):

View File

@ -52,12 +52,14 @@ class Worker(Thread):
self.jobs_queue = jobs_queue self.jobs_queue = jobs_queue
self.supervise_queue = supervise_queue self.supervise_queue = supervise_queue
self.keep_going = True self.keep_going = True
self.working = False
def run(self): def run(self):
while self.keep_going: while self.keep_going:
x = self.jobs_queue.get() x = self.jobs_queue.get()
if x is quit: if x is quit:
break break
self.working = True
try: try:
res = self.run_job(x) res = self.run_job(x)
if res is not None: if res is not None:
@ -66,6 +68,8 @@ class Worker(Thread):
tb = traceback.format_exc() tb = traceback.format_exc()
traceback.print_exc() traceback.print_exc()
self.supervise_queue.put(Result(x, tb)) self.supervise_queue.put(Result(x, tb))
finally:
self.working = False
def run_job(self, job): def run_job(self, job):
txtpath = job.path + '.txt' txtpath = job.path + '.txt'
@ -77,7 +81,9 @@ class Worker(Thread):
stdout=subprocess.DEVNULL, stderr=error, stdin=subprocess.DEVNULL, priority='low', stdout=subprocess.DEVNULL, stderr=error, stdin=subprocess.DEVNULL, priority='low',
) )
while self.keep_going: while self.keep_going:
p.wait(0.1) with suppress(subprocess.TimeoutExpired):
p.wait(0.1)
break
if p.returncode is None: if p.returncode is None:
p.kill() p.kill()
return return
@ -146,7 +152,12 @@ class Pool:
extra -= 1 extra -= 1
# external API {{{ # external API {{{
def set_num_of_workers(self, num): @property
def num_of_workers(self):
return len(self.workers)
@num_of_workers.setter
def num_of_workers(self, num):
self.initialize() self.initialize()
self.prune_dead_workers() self.prune_dead_workers()
num = max(1, num) num = max(1, num)
@ -156,6 +167,10 @@ class Pool:
elif num < self.workers: elif num < self.workers:
self.shrink_workers() self.shrink_workers()
@property
def num_of_idle_workers(self):
return sum(1 if w.working else 0 for w in self.workers)
def check_for_work(self): def check_for_work(self):
self.initialize() self.initialize()
self.supervise_queue.put(check_for_work) self.supervise_queue.put(check_for_work)

View File

@ -29,13 +29,19 @@ class FTSAPITest(BaseTest):
from calibre_extensions.sqlite_extension import set_ui_language from calibre_extensions.sqlite_extension import set_ui_language
set_ui_language('en') set_ui_language('en')
def wait_for_fts_to_finish(self, fts, timeout=10):
if fts.pool.initialized:
st = time.monotonic()
while fts.all_currently_dirty() and time.monotonic() - st < timeout:
fts.pool.supervisor_thread.join(0.01)
def test_fts_pool(self): def test_fts_pool(self):
cache = self.init_cache() cache = self.init_cache()
fts = cache.enable_fts(start_pool=True) fts = cache.enable_fts()
st = time.monotonic() self.wait_for_fts_to_finish(fts)
while fts.all_currently_dirty() and time.monotonic() - st < 2:
fts.pool.supervisor_thread.join(0.01)
self.assertFalse(fts.all_currently_dirty()) self.assertFalse(fts.all_currently_dirty())
cache.add_format(1, 'TXT', BytesIO(b'a test text'))
self.wait_for_fts_to_finish(fts)
def test_fts_triggers(self): def test_fts_triggers(self):
cache = self.init_cache() cache = self.init_cache()