When renaming return id map as well

This commit is contained in:
Kovid Goyal 2013-07-16 15:09:03 +05:30
parent 0c6d820f2b
commit 7819f37a45
3 changed files with 20 additions and 13 deletions

View File

@ -1229,8 +1229,11 @@ class Cache(object):
except AttributeError:
raise ValueError('Cannot rename items for one-one fields: %s' % field)
affected_books = set()
id_map = {}
for item_id, new_name in item_id_to_new_name_map.iteritems():
affected_books.update(func(item_id, new_name, self.backend))
books, new_id = func(item_id, new_name, self.backend)
affected_books.update(books)
id_map[item_id] = new_id
if affected_books:
if field == 'authors':
self._set_field('author_sort', # also marks as dirty
@ -1238,7 +1241,7 @@ class Cache(object):
self._update_path(affected_books, mark_as_dirtied=False)
else:
self._mark_as_dirty(affected_books)
return affected_books
return affected_books, id_map
@write_api
def remove_items(self, field, item_ids):

View File

@ -227,12 +227,14 @@ class ManyToOneTable(Table):
existing_item = rmap.get(icu_lower(new_name), None)
table, col, lcol = self.metadata['table'], self.metadata['column'], self.metadata['link_column']
affected_books = self.col_book_map.get(item_id, set())
new_id = item_id
if existing_item is None or existing_item == item_id:
# A simple rename will do the trick
self.id_map[item_id] = new_name
db.conn.execute('UPDATE {0} SET {1}=? WHERE id=?'.format(table, col), (new_name, item_id))
else:
# We have to replace
new_id = existing_item
self.id_map.pop(item_id, None)
books = self.col_book_map.pop(item_id, set())
for book_id in books:
@ -243,7 +245,7 @@ class ManyToOneTable(Table):
# handle that in this context.
db.conn.execute('UPDATE {0} SET {1}=? WHERE {1}=?; DELETE FROM {2} WHERE id=?'.format(
self.link_table, lcol, table), (existing_item, item_id, item_id))
return affected_books
return affected_books, new_id
class ManyToManyTable(ManyToOneTable):
@ -311,12 +313,14 @@ class ManyToManyTable(ManyToOneTable):
existing_item = rmap.get(icu_lower(new_name), None)
table, col, lcol = self.metadata['table'], self.metadata['column'], self.metadata['link_column']
affected_books = self.col_book_map.get(item_id, set())
new_id = item_id
if existing_item is None or existing_item == item_id:
# A simple rename will do the trick
self.id_map[item_id] = new_name
db.conn.execute('UPDATE {0} SET {1}=? WHERE id=?'.format(table, col), (new_name, item_id))
else:
# We have to replace
new_id = existing_item
self.id_map.pop(item_id, None)
books = self.col_book_map.pop(item_id, set())
# Replacing item_id with existing_item could cause the same id to
@ -329,7 +333,7 @@ class ManyToManyTable(ManyToOneTable):
(book_id, existing_item) for book_id in books])
db.conn.execute('UPDATE {0} SET {1}=? WHERE {1}=?; DELETE FROM {2} WHERE id=?'.format(
self.link_table, lcol, table), (existing_item, item_id, item_id))
return affected_books
return affected_books, new_id
class AuthorsTable(ManyToManyTable):

View File

@ -481,9 +481,9 @@ class WritingTest(BaseTest):
cache = self.init_cache(cl)
# Check that renaming authors updates author sort and path
a = {v:k for k, v in cache.get_id_map('authors').iteritems()}['Unknown']
self.assertEqual(cache.rename_items('authors', {a:'New Author'}), {3})
self.assertEqual(cache.rename_items('authors', {a:'New Author'})[0], {3})
a = {v:k for k, v in cache.get_id_map('authors').iteritems()}['Author One']
self.assertEqual(cache.rename_items('authors', {a:'Author Two'}), {1, 2})
self.assertEqual(cache.rename_items('authors', {a:'Author Two'})[0], {1, 2})
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('authors'), {'New Author', 'Author Two'})
self.assertEqual(c.field_for('author_sort', 3), 'Author, New')
@ -493,26 +493,26 @@ class WritingTest(BaseTest):
t = {v:k for k, v in cache.get_id_map('tags').iteritems()}['Tag One']
# Test case change
self.assertEqual(cache.rename_items('tags', {t:'tag one'}), {1, 2})
self.assertEqual(cache.rename_items('tags', {t:'tag one'}), ({1, 2}, {t:t}))
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('tags'), {'tag one', 'Tag Two', 'News'})
self.assertEqual(set(c.field_for('tags', 1)), {'tag one', 'News'})
self.assertEqual(set(c.field_for('tags', 2)), {'tag one', 'Tag Two'})
# Test new name
self.assertEqual(cache.rename_items('tags', {t:'t1'}), {1,2})
self.assertEqual(cache.rename_items('tags', {t:'t1'})[0], {1,2})
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('tags'), {'t1', 'Tag Two', 'News'})
self.assertEqual(set(c.field_for('tags', 1)), {'t1', 'News'})
self.assertEqual(set(c.field_for('tags', 2)), {'t1', 'Tag Two'})
# Test rename to existing
self.assertEqual(cache.rename_items('tags', {t:'Tag Two'}), {1,2})
self.assertEqual(cache.rename_items('tags', {t:'Tag Two'})[0], {1,2})
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('tags'), {'Tag Two', 'News'})
self.assertEqual(set(c.field_for('tags', 1)), {'Tag Two', 'News'})
self.assertEqual(set(c.field_for('tags', 2)), {'Tag Two'})
# Test on a custom column
t = {v:k for k, v in cache.get_id_map('#tags').iteritems()}['My Tag One']
self.assertEqual(cache.rename_items('#tags', {t:'My Tag Two'}), {2})
self.assertEqual(cache.rename_items('#tags', {t:'My Tag Two'})[0], {2})
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('#tags'), {'My Tag Two'})
self.assertEqual(set(c.field_for('#tags', 2)), {'My Tag Two'})
@ -520,14 +520,14 @@ class WritingTest(BaseTest):
# Test a Many-one field
s = {v:k for k, v in cache.get_id_map('series').iteritems()}['A Series One']
# Test case change
self.assertEqual(cache.rename_items('series', {s:'a series one'}), {1, 2})
self.assertEqual(cache.rename_items('series', {s:'a series one'}), ({1, 2}, {s:s}))
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('series'), {'a series one'})
self.assertEqual(c.field_for('series', 1), 'a series one')
self.assertEqual(c.field_for('series_index', 1), 2.0)
# Test new name
self.assertEqual(cache.rename_items('series', {s:'series'}), {1, 2})
self.assertEqual(cache.rename_items('series', {s:'series'})[0], {1, 2})
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('series'), {'series'})
self.assertEqual(c.field_for('series', 1), 'series')
@ -536,7 +536,7 @@ class WritingTest(BaseTest):
s = {v:k for k, v in cache.get_id_map('#series').iteritems()}['My Series One']
# Test custom column with rename to existing
self.assertEqual(cache.rename_items('#series', {s:'My Series Two'}), {2})
self.assertEqual(cache.rename_items('#series', {s:'My Series Two'})[0], {2})
for c in (cache, self.init_cache(cl)):
self.assertEqual(c.all_field_names('#series'), {'My Series Two'})
self.assertEqual(c.field_for('#series', 2), 'My Series Two')