Fix TableRow implementation and when iterating over ids iterate in ascending order to match legacy interface

This commit is contained in:
Kovid Goyal 2013-05-02 12:31:56 +05:30
parent 403b12bb82
commit c01b8ea033
3 changed files with 44 additions and 7 deletions

View File

@ -11,6 +11,7 @@ from functools import partial
from calibre.db.backend import DB from calibre.db.backend import DB
from calibre.db.cache import Cache from calibre.db.cache import Cache
from calibre.db.categories import CATEGORY_SORTS
from calibre.db.view import View from calibre.db.view import View
from calibre.utils.date import utcnow from calibre.utils.date import utcnow
@ -20,6 +21,10 @@ class LibraryDatabase(object):
PATH_LIMIT = DB.PATH_LIMIT PATH_LIMIT = DB.PATH_LIMIT
WINDOWS_LIBRARY_PATH_LIMIT = DB.WINDOWS_LIBRARY_PATH_LIMIT WINDOWS_LIBRARY_PATH_LIMIT = DB.WINDOWS_LIBRARY_PATH_LIMIT
CATEGORY_SORTS = CATEGORY_SORTS
MATCH_TYPE = ('any', 'all')
CUSTOM_DATA_TYPES = frozenset(['rating', 'text', 'comments', 'datetime',
'int', 'float', 'bool', 'series', 'composite', 'enumeration'])
@classmethod @classmethod
def exists_at(cls, path): def exists_at(cls, path):
@ -148,3 +153,7 @@ class LibraryDatabase(object):
os.makedirs(path) os.makedirs(path)
return path return path
def __iter__(self):
for row in self.data.iterall():
yield row

View File

@ -82,6 +82,7 @@ class LegacyTest(BaseTest):
# }}} # }}}
def test_legacy_getters(self): # {{{ def test_legacy_getters(self): # {{{
' Test various functions to get individual bits of metadata '
old = self.init_old() old = self.init_old()
getters = ('path', 'abspath', 'title', 'authors', 'series', getters = ('path', 'abspath', 'title', 'authors', 'series',
'publisher', 'author_sort', 'authors', 'comments', 'publisher', 'author_sort', 'authors', 'comments',
@ -89,11 +90,29 @@ class LegacyTest(BaseTest):
'timestamp', 'uuid', 'pubdate', 'ondevice', 'timestamp', 'uuid', 'pubdate', 'ondevice',
'metadata_last_modified', 'languages') 'metadata_last_modified', 'languages')
oldvals = {g:tuple(getattr(old, g)(x) for x in xrange(3)) + tuple(getattr(old, g)(x, True) for x in (1,2,3)) for g in getters} oldvals = {g:tuple(getattr(old, g)(x) for x in xrange(3)) + tuple(getattr(old, g)(x, True) for x in (1,2,3)) for g in getters}
old_rows = {tuple(r)[:5] for r in old}
old.close() old.close()
db = self.init_legacy() db = self.init_legacy()
newvals = {g:tuple(getattr(db, g)(x) for x in xrange(3)) + tuple(getattr(db, g)(x, True) for x in (1,2,3)) for g in getters} newvals = {g:tuple(getattr(db, g)(x) for x in xrange(3)) + tuple(getattr(db, g)(x, True) for x in (1,2,3)) for g in getters}
new_rows = {tuple(r)[:5] for r in db}
for x in (oldvals, newvals): for x in (oldvals, newvals):
x['tags'] = tuple(set(y.split(',')) if y else y for y in x['tags']) x['tags'] = tuple(set(y.split(',')) if y else y for y in x['tags'])
self.assertEqual(oldvals, newvals) self.assertEqual(oldvals, newvals)
self.assertEqual(old_rows, new_rows)
# }}} # }}}
def test_legacy_coverage(self): # {{{
' Check that the emulation of the legacy interface is (almost) total '
cl = self.cloned_library
db = self.init_old(cl)
ndb = self.init_legacy()
SKIP_ATTRS = {'TCat_Tag'}
for attr in dir(db):
if attr in SKIP_ATTRS:
continue
self.assertTrue(hasattr(ndb, attr), 'The attribute %s is missing' % attr)
# obj = getattr(db, attr)
# }}}

View File

@ -29,11 +29,12 @@ class MarkedVirtualField(object):
for book_id in candidates: for book_id in candidates:
yield self.marked_ids.get(book_id, default_value), {book_id} yield self.marked_ids.get(book_id, default_value), {book_id}
class TableRow(list): class TableRow(object):
def __init__(self, book_id, view): def __init__(self, book_id, view):
self.book_id = book_id self.book_id = book_id
self.view = weakref.ref(view) self.view = weakref.ref(view)
self.column_count = view.column_count
def __getitem__(self, obj): def __getitem__(self, obj):
view = self.view() view = self.view()
@ -43,6 +44,13 @@ class TableRow(list):
else: else:
return view._field_getters[obj](self.book_id) return view._field_getters[obj](self.book_id)
def __len__(self):
return self.column_count
def __iter__(self):
for i in xrange(self.column_count):
yield self[i]
def format_is_multiple(x, sep=',', repl=None): def format_is_multiple(x, sep=',', repl=None):
if not x: if not x:
return None return None
@ -67,6 +75,7 @@ class View(object):
self.search_restriction = self.base_restriction = '' self.search_restriction = self.base_restriction = ''
self.search_restriction_name = self.base_restriction_name = '' self.search_restriction_name = self.base_restriction_name = ''
self._field_getters = {} self._field_getters = {}
self.column_count = len(cache.backend.FIELD_MAP)
for col, idx in cache.backend.FIELD_MAP.iteritems(): for col, idx in cache.backend.FIELD_MAP.iteritems():
label, fmt = col, lambda x:x label, fmt = col, lambda x:x
func = { func = {
@ -107,7 +116,7 @@ class View(object):
fmt = partial(format_is_multiple, sep=sep) fmt = partial(format_is_multiple, sep=sep)
self._field_getters[idx] = partial(func, label, fmt=fmt) if func == self._get else func self._field_getters[idx] = partial(func, label, fmt=fmt) if func == self._get else func
self._map = tuple(self.cache.all_book_ids()) self._map = tuple(sorted(self.cache.all_book_ids()))
self._map_filtered = tuple(self._map) self._map_filtered = tuple(self._map)
def get_property(self, id_or_index, index_is_id=False, loc=-1): def get_property(self, id_or_index, index_is_id=False, loc=-1):
@ -124,21 +133,21 @@ class View(object):
return idx if index_is_id else self.index_to_id(idx) return idx if index_is_id else self.index_to_id(idx)
def __getitem__(self, row): def __getitem__(self, row):
return TableRow(self._map_filtered[row], self.cache) return TableRow(self._map_filtered[row], self)
def __len__(self): def __len__(self):
return len(self._map_filtered) return len(self._map_filtered)
def __iter__(self): def __iter__(self):
for book_id in self._map_filtered: for book_id in self._map_filtered:
yield self._data[book_id] yield TableRow(book_id, self)
def iterall(self): def iterall(self):
for book_id in self._map: for book_id in self.iterallids():
yield self[book_id] yield TableRow(book_id, self)
def iterallids(self): def iterallids(self):
for book_id in self._map: for book_id in sorted(self._map):
yield book_id yield book_id
def get_field_map_field(self, row, col, index_is_id=True): def get_field_map_field(self, row, col, index_is_id=True):