Add a case-sensitive version of the get_item_id* API

Case-sensitive lookups are an order of magnitude faster for large DBs, as the
string comparison is done in C by SQLite.
This commit is contained in:
Kovid Goyal 2024-09-15 06:43:26 +05:30
parent 68c4f734f7
commit f94fbc113a
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 48 additions and 15 deletions

View File

@ -946,22 +946,23 @@ class Cache:
return self.fields[field].table.id_map[item_id] return self.fields[field].table.id_map[item_id]
@read_api @read_api
def get_item_id(self, field, item_name): def get_item_id(self, field, item_name, case_sensitive=False):
''' Return the item id for item_name (case-insensitive) or None if not found. ''' Return the item id for item_name or None if not found.
This function is very slow if doing lookups for multiple names use either get_item_ids() or get_item_name_map(). ''' This function is very slow if doing lookups for multiple names use either get_item_ids() or get_item_name_map().
q = icu_lower(item_name) Similarly, case sensitive lookups are faster than case insensitive ones. '''
try: field = self.fields[field]
for item_id, item_val in self.fields[field].table.id_map.items(): if hasattr(field, 'item_ids_for_names'):
if icu_lower(item_val) == q: d = field.item_ids_for_names(self.backend, (item_name,), case_sensitive)
return item_id for v in d.values():
except KeyError: return v
return None
@read_api @read_api
def get_item_ids(self, field, item_names): def get_item_ids(self, field, item_names, case_sensitive=False):
' Return the item id for item_name (case-insensitive) ' ' Return a dict mapping item_name to the item id or None '
rmap = {icu_lower(v) if isinstance(v, str) else v:k for k, v in iteritems(self.fields[field].table.id_map)} field = self.fields[field]
return {name:rmap.get(icu_lower(name) if isinstance(name, str) else name, None) for name in item_names} if hasattr(field, 'item_ids_for_names'):
return field.item_ids_for_names(self.backend, item_names, case_sensitive)
return dict.fromkeys(item_names)
@read_api @read_api
def get_item_name_map(self, field, normalize_func=None): def get_item_name_map(self, field, normalize_func=None):

View File

@ -9,6 +9,7 @@ import sys
from collections import Counter, defaultdict from collections import Counter, defaultdict
from functools import partial from functools import partial
from threading import Lock from threading import Lock
from typing import Iterable
from calibre.db.tables import MANY_MANY, MANY_ONE, ONE_ONE, null from calibre.db.tables import MANY_MANY, MANY_ONE, ONE_ONE, null
from calibre.db.utils import atof, force_to_bool from calibre.db.utils import atof, force_to_bool
@ -531,6 +532,9 @@ class ManyToOneField(Field):
except KeyError: except KeyError:
raise InvalidLinkTable(self.name) raise InvalidLinkTable(self.name)
def item_ids_for_names(self, db, item_names: Iterable[str], case_sensitive: bool = False) -> dict[str, int]:
    ''' Return a mapping of the given item names to item ids, as computed
    by this field's table. All arguments are forwarded unchanged. '''
    table = self.table
    return table.item_ids_for_names(db, item_names, case_sensitive)
class ManyToManyField(Field): class ManyToManyField(Field):
@ -540,6 +544,9 @@ class ManyToManyField(Field):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
Field.__init__(self, *args, **kwargs) Field.__init__(self, *args, **kwargs)
def item_ids_for_names(self, db, item_names: Iterable[str], case_sensitive: bool = False) -> dict[str, int]:
    ''' Resolve item names to item ids by handing the lookup to the
    table backing this field; the arguments are passed through as-is. '''
    lookup = self.table.item_ids_for_names
    return lookup(db, item_names, case_sensitive)
def for_book(self, book_id, default_value=None): def for_book(self, book_id, default_value=None):
ids = self.table.book_col_map.get(book_id, ()) ids = self.table.book_col_map.get(book_id, ())
if ids: if ids:

View File

@ -8,6 +8,7 @@ __docformat__ = 'restructuredtext en'
import numbers import numbers
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Iterable
from calibre.ebooks.metadata import author_to_author_sort from calibre.ebooks.metadata import author_to_author_sort
from calibre.utils.date import UNDEFINED_DATE, parse_date, utc_tz from calibre.utils.date import UNDEFINED_DATE, parse_date, utc_tz
@ -263,6 +264,26 @@ class ManyToOneTable(Table):
tuple((main_id, x) for x in v)) tuple((main_id, x) for x in v))
db.delete_category_items(self.name, self.metadata['table'], item_map) db.delete_category_items(self.name, self.metadata['table'], item_map)
def item_ids_for_names(self, db, item_names: Iterable[str], case_sensitive: bool = False) -> dict[str, int]:
    '''
    Return a dict mapping every name in item_names to the id of the
    matching item in this table, or to None when no item matches.

    :param db: Backend object. Only its ``get()`` method is used, and only
        for case-sensitive lookups.
    :param item_names: The names to look up. The iterable is consumed once
        and stored as a tuple, so generators are acceptable.
    :param case_sensitive: When True, matching is done by an SQL query
        against this table's column. When False, names are compared
        with icu_lower() against the in-memory ``id_map``.
    '''
    item_names = tuple(item_names)
    if not item_names:
        # Nothing to look up: avoid building the lowered map or running a query
        return {}
    if case_sensitive:
        colname = self.metadata['column']
        table = self.metadata['table']
        if len(item_names) == 1:
            # NOTE(review): relies on db.get(..., all=False) yielding None
            # when no row matches -- confirm against the backend
            iid = db.get(f'SELECT id FROM {table} WHERE {colname} = ?', (item_names[0],), all=False)
            return {item_names[0]: iid}
        inq = ','.join('?' * len(item_names))
        # Pre-fill with None so names absent from the query result are still reported
        ans = dict.fromkeys(item_names)
        ans.update(db.get(f'SELECT {colname}, id FROM {table} WHERE {colname} IN ({inq})', item_names))
        return ans
    if len(item_names) == 1:
        # A single name: scan and stop at the first hit rather than
        # lower-casing the whole id_map into a reverse map
        q = icu_lower(item_names[0])
        for iid, name in self.id_map.items():
            if icu_lower(name) == q:
                return {item_names[0]: iid}
        # No match: report None, not the id left over from the last loop
        # iteration (which would also be unbound for an empty id_map)
        return {item_names[0]: None}
    rmap = {icu_lower(v) if isinstance(v, str) else v:k for k, v in self.id_map.items()}
    return {name: rmap.get(icu_lower(name) if isinstance(name, str) else name, None) for name in item_names}
def remove_books(self, book_ids, db): def remove_books(self, book_ids, db):
clean = set() clean = set()
for book_id in book_ids: for book_id in book_ids:

View File

@ -497,8 +497,12 @@ class WritingTest(BaseTest):
# auto-generated authors sort # auto-generated authors sort
mi = Metadata('empty', ['a1', 'a2']) mi = Metadata('empty', ['a1', 'a2'])
cache.set_metadata(1, mi) cache.set_metadata(1, mi)
self.assertEqual(cache.get_item_ids('authors', ('a1', 'a2')), cache.get_item_ids('authors', ('a1', 'a2'), case_sensitive=True))
self.assertEqual(
set(cache.get_item_ids('authors', ('A1', 'a2')).values()),
set(cache.get_item_ids('authors', ('a1', 'a2'), case_sensitive=True).values()))
self.assertEqual('a1 & a2', cache.field_for('author_sort', 1)) self.assertEqual('a1 & a2', cache.field_for('author_sort', 1))
cache.set_sort_for_authors({cache.get_item_id('authors', 'a1'): 'xy'}) cache.set_sort_for_authors({cache.get_item_id('authors', 'a1', case_sensitive=True): 'xy'})
self.assertEqual('xy & a2', cache.field_for('author_sort', 1)) self.assertEqual('xy & a2', cache.field_for('author_sort', 1))
mi = Metadata('empty', ['a1']) mi = Metadata('empty', ['a1'])
cache.set_metadata(1, mi) cache.set_metadata(1, mi)