Implement writing series

This commit is contained in:
Kovid Goyal 2013-03-01 16:26:07 +05:30
parent 51bbe72ec2
commit b98186f77e
3 changed files with 74 additions and 9 deletions

View File

@ -19,6 +19,7 @@ from calibre.db.errors import NoSuchFormat
from calibre.db.fields import create_field from calibre.db.fields import create_field
from calibre.db.search import Search from calibre.db.search import Search
from calibre.db.tables import VirtualTable from calibre.db.tables import VirtualTable
from calibre.db.write import get_series_values
from calibre.db.lazy import FormatMetadata, FormatsList from calibre.db.lazy import FormatMetadata, FormatsList
from calibre.ebooks.metadata.book.base import Metadata from calibre.ebooks.metadata.book.base import Metadata
from calibre.ptempfile import (base_dir, PersistentTemporaryFile, from calibre.ptempfile import (base_dir, PersistentTemporaryFile,
@ -619,8 +620,30 @@ class Cache(object):
# TODO: Specialize title/authors to also update path # TODO: Specialize title/authors to also update path
# TODO: Handle updating caches used by composite fields # TODO: Handle updating caches used by composite fields
# TODO: Ensure the sort fields are updated for title/author/series? # TODO: Ensure the sort fields are updated for title/author/series?
dirtied = self.fields[name].writer.set_books( f = self.fields[name]
is_series = f.metadata['datatype'] == 'series'
if is_series:
bimap, simap = {}, {}
for k, v in book_id_to_val_map.iteritems():
if isinstance(v, basestring):
v, sid = get_series_values(v)
else:
v = sid = None
if name.startswith('#') and sid is None:
sid = 1.0 # The value will be set to 1.0 in the db table
bimap[k] = v
if sid is not None:
simap[k] = sid
book_id_to_val_map = bimap
dirtied = f.writer.set_books(
book_id_to_val_map, self.backend, allow_case_change=allow_case_change) book_id_to_val_map, self.backend, allow_case_change=allow_case_change)
if is_series and simap:
sf = self.fields[f.name+'_index']
dirtied |= sf.writer.set_books(simap, self.backend, allow_case_change=False)
return dirtied return dirtied
# }}} # }}}

View File

@ -75,7 +75,7 @@ class WritingTest(BaseTest):
test.name, old_sqlite_res, sqlite_res)) test.name, old_sqlite_res, sqlite_res))
del db del db
def test_one_one(self): def test_one_one(self): # {{{
'Test setting of values in one-one fields' 'Test setting of values in one-one fields'
tests = [self.create_test('#yesno', (True, False, 'true', 'false', None))] tests = [self.create_test('#yesno', (True, False, 'true', 'false', None))]
for name, getter, setter in ( for name, getter, setter in (
@ -114,8 +114,9 @@ class WritingTest(BaseTest):
tests.append(self.create_test(name, tuple(vals), getter, setter)) tests.append(self.create_test(name, tuple(vals), getter, setter))
self.run_tests(tests) self.run_tests(tests)
# }}}
def test_many_one_basic(self): def test_many_one_basic(self): # {{{
'Test the different code paths for writing to a many-one field' 'Test the different code paths for writing to a many-one field'
cl = self.cloned_library cl = self.cloned_library
cache = self.init_cache(cl) cache = self.init_cache(cl)
@ -159,8 +160,6 @@ class WritingTest(BaseTest):
self.assertEqual(tuple(map(f.for_book, (1,2,3))), ('Two', 'Two', 'three')) self.assertEqual(tuple(map(f.for_book, (1,2,3))), ('Two', 'Two', 'three'))
del cache2 del cache2
# TODO: Test different column types series, #series,
# Enum # Enum
self.assertFalse(cache.set_field('#enum', {1:'Not allowed'})) self.assertFalse(cache.set_field('#enum', {1:'Not allowed'}))
self.assertEqual(cache.set_field('#enum', {1:'One', 2:'One', 3:'Three'}), {1, 3}) self.assertEqual(cache.set_field('#enum', {1:'One', 2:'One', 3:'Three'}), {1, 3})
@ -183,6 +182,29 @@ class WritingTest(BaseTest):
self.assertEqual(c.field_for('#rating', i), val) self.assertEqual(c.field_for('#rating', i), val)
del cache2 del cache2
# Series
self.assertFalse(cache.set_field('series',
{1:'a series one', 2:'a series one'}, allow_case_change=False))
self.assertEqual(cache.set_field('series', {3:'Series [3]'}), set([3]))
self.assertEqual(cache.set_field('#series', {1:'Series', 3:'Series'}),
{1, 3})
self.assertEqual(cache.set_field('#series', {2:'Series [0]'}), set([2]))
cache2 = self.init_cache(cl)
for c in (cache, cache2):
for i, val in {1:'A Series One', 2:'A Series One', 3:'Series'}.iteritems():
self.assertEqual(c.field_for('series', i), val)
for i in (1, 2, 3):
self.assertEqual(c.field_for('#series', i), 'Series')
for i, val in {1:2, 2:1, 3:3}.iteritems():
self.assertEqual(c.field_for('series_index', i), val)
for i, val in {1:1, 2:0, 3:1}.iteritems():
self.assertEqual(c.field_for('#series_index', i), val)
del cache2
# }}}
def tests(): def tests():
return unittest.TestLoader().loadTestsFromTestCase(WritingTest) return unittest.TestLoader().loadTestsFromTestCase(WritingTest)

View File

@ -7,6 +7,7 @@ __license__ = 'GPL v3'
__copyright__ = '2013, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2013, Kovid Goyal <kovid at kovidgoyal.net>'
__docformat__ = 'restructuredtext en' __docformat__ = 'restructuredtext en'
import re
from functools import partial from functools import partial
from datetime import datetime from datetime import datetime
@ -29,6 +30,21 @@ def single_text(x):
x = x.strip() x = x.strip()
return x if x else None return x if x else None
series_index_pat = re.compile(r'(.*)\s+\[([.0-9]+)\]$')
def get_series_values(val):
if not val:
return (val, None)
match = series_index_pat.match(val.strip())
if match is not None:
idx = match.group(2)
try:
idx = float(idx)
return (match.group(1).strip(), idx)
except:
pass
return (val, None)
def multiple_text(sep, x): def multiple_text(sep, x):
if x is None: if x is None:
return () return ()
@ -151,7 +167,7 @@ def custom_series_index(book_id_val_map, db, field, *args):
if sequence: if sequence:
db.conn.executemany('UPDATE %s SET %s=? WHERE book=? AND value=?'%( db.conn.executemany('UPDATE %s SET %s=? WHERE book=? AND value=?'%(
field.metadata['table'], field.metadata['column']), sequence) field.metadata['table'], field.metadata['column']), sequence)
return {s[0] for s in sequence} return {s[1] for s in sequence}
# }}} # }}}
# Many-One fields {{{ # Many-One fields {{{
@ -167,9 +183,10 @@ def many_one(book_id_val_map, db, field, allow_case_change, *args):
m = field.metadata m = field.metadata
table = field.table table = field.table
dt = m['datatype'] dt = m['datatype']
is_custom_series = dt == 'series' and table.name.startswith('#')
# Map values to their canonical form for later comparison # Map values to their canonical form for later comparison
kmap = safe_lower if dt == 'text' else lambda x:x kmap = safe_lower if dt in {'text', 'series'} else lambda x:x
# Ignore those items whose value is the same as the current value # Ignore those items whose value is the same as the current value
no_changes = {k:nval for k, nval in book_id_val_map.iteritems() if no_changes = {k:nval for k, nval in book_id_val_map.iteritems() if
@ -215,9 +232,12 @@ def many_one(book_id_val_map, db, field, allow_case_change, *args):
changed_items = {k:book_id_item_id_map[k] for k in updated if changed_items = {k:book_id_item_id_map[k] for k in updated if
book_id_item_id_map[k] is not None} book_id_item_id_map[k] is not None}
def sql_update(imap): def sql_update(imap):
db.conn.executemany( sql = (
'DELETE FROM {0} WHERE book=?; INSERT INTO {0}(book,{1},extra) VALUES(?, ?, 1.0)'
if is_custom_series else
'DELETE FROM {0} WHERE book=?; INSERT INTO {0}(book,{1}) VALUES(?, ?)' 'DELETE FROM {0} WHERE book=?; INSERT INTO {0}(book,{1}) VALUES(?, ?)'
.format(link_table, m['link_column']), )
db.conn.executemany(sql.format(link_table, m['link_column']),
tuple((book_id, book_id, item_id) for book_id, item_id in tuple((book_id, book_id, item_id) for book_id, item_id in
imap.iteritems())) imap.iteritems()))