Finish up the logic for writing many-one fields

This commit is contained in:
Kovid Goyal 2013-02-28 22:55:55 +05:30
parent 3bb5796162
commit 8d649bc403
2 changed files with 112 additions and 27 deletions

View File

@ -115,6 +115,52 @@ class WritingTest(BaseTest):
self.run_tests(tests) self.run_tests(tests)
def test_many_one_basic(self):
'Test the different code paths for writing to a many-one field'
cl = self.cloned_library
cache = self.init_cache(cl)
f = cache.fields['publisher']
item_ids = {f.ids_for_book(1)[0], f.ids_for_book(2)[0]}
val = 'Changed'
self.assertEqual(cache.set_field('publisher', {1:val, 2:val}), {1, 2})
cache2 = self.init_cache(cl)
for book_id in (1, 2):
for c in (cache, cache2):
self.assertEqual(c.field_for('publisher', book_id), val)
self.assertFalse(item_ids.intersection(set(c.fields['publisher'].table.id_map)))
del cache2
self.assertFalse(cache.set_field('publisher', {1:val, 2:val}))
val = val.lower()
self.assertFalse(cache.set_field('publisher', {1:val, 2:val},
allow_case_change=False))
self.assertEqual(cache.set_field('publisher', {1:val, 2:val}), {1, 2})
cache2 = self.init_cache(cl)
for book_id in (1, 2):
for c in (cache, cache2):
self.assertEqual(c.field_for('publisher', book_id), val)
del cache2
self.assertEqual(cache.set_field('publisher', {1:'new', 2:'New'}), {1, 2})
self.assertEqual(cache.field_for('publisher', 1).lower(), 'new')
self.assertEqual(cache.field_for('publisher', 2).lower(), 'new')
self.assertEqual(cache.set_field('publisher', {1:None, 2:'NEW'}), {1, 2})
self.assertEqual(len(f.table.id_map), 1)
self.assertEqual(cache.set_field('publisher', {2:None}), {2})
self.assertEqual(len(f.table.id_map), 0)
cache2 = self.init_cache(cl)
self.assertEqual(len(cache2.fields['publisher'].table.id_map), 0)
del cache2
self.assertEqual(cache.set_field('publisher', {1:'one', 2:'two',
3:'three'}), {1, 2, 3})
self.assertEqual(cache.set_field('publisher', {1:''}), set([1]))
self.assertEqual(cache.set_field('publisher', {1:'two'}), set([1]))
self.assertEqual(tuple(map(f.for_book, (1,2,3))), ('two', 'two', 'three'))
self.assertEqual(cache.set_field('publisher', {1:'Two'}), {1, 2})
cache2 = self.init_cache(cl)
self.assertEqual(tuple(map(f.for_book, (1,2,3))), ('Two', 'Two', 'three'))
del cache2
# TODO: Test different column types
def tests(): def tests():
return unittest.TestLoader().loadTestsFromTestCase(WritingTest) return unittest.TestLoader().loadTestsFromTestCase(WritingTest)

View File

@ -156,56 +156,94 @@ def custom_series_index(book_id_val_map, db, field, *args):
# Many-One fields {{{ # Many-One fields {{{
def safe_lower(x):
try:
return icu_lower(x)
except (TypeError, ValueError, KeyError, AttributeError):
return x
def many_one(book_id_val_map, db, field, allow_case_change, *args): def many_one(book_id_val_map, db, field, allow_case_change, *args):
dirtied = set() dirtied = set()
m = field.metadata m = field.metadata
table = field.table
dt = m['datatype'] dt = m['datatype']
kmap = icu_lower if dt == 'text' else lambda x:x
rid_map = {kmap(v):k for k, v in field.table.id_map.iteritems()} # Map values to their canonical form for later comparison
book_id_item_id_map = {k:rid_map.get(kmap(v), None) if v is not None else kmap = safe_lower if dt == 'text' else lambda x:x
None for k, v in book_id_val_map.iteritems()}
# Ignore those items whose value is the same as the current value
no_changes = {k:nval for k, nval in book_id_val_map.iteritems() if
kmap(nval) == kmap(field.for_book(k, default_value=None))}
for book_id in no_changes:
del book_id_val_map[book_id]
# If we are allowed case changes check that none of the ignored items are
# case changes. If they are, update the item's case in the db.
if allow_case_change: if allow_case_change:
for book_id, item_id in book_id_item_id_map.iteritems(): for book_id, nval in no_changes.iteritems():
nval = book_id_val_map[book_id] if nval is not None and nval != field.for_book(
if (item_id is not None and nval != field.table.id_map[item_id]): book_id, default_value=None):
# Change of case # Change of case
item_id = table.book_col_map[book_id]
db.conn.execute('UPDATE %s SET %s=? WHERE id=?'%( db.conn.execute('UPDATE %s SET %s=? WHERE id=?'%(
m['table'], m['column']), (nval, item_id)) m['table'], m['column']), (nval, item_id))
field.table.id_map[item_id] = nval table.id_map[item_id] = nval
dirtied |= field.table.col_book_map[item_id] dirtied |= table.col_book_map[item_id]
deleted = {k:v for k, v in book_id_val_map.iteritems() if v is None} deleted = {k:v for k, v in book_id_val_map.iteritems() if v is None}
updated = {k:v for k, v in book_id_val_map.iteritems() if v is not None} updated = {k:v for k, v in book_id_val_map.iteritems() if v is not None}
link_table = table.link_table
if deleted: if deleted:
db.conn.executemany('DELETE FROM %s WHERE book=?'%m['link_table'], db.conn.executemany('DELETE FROM %s WHERE book=?'%link_table,
tuple((book_id,) for book_id in deleted)) tuple((book_id,) for book_id in deleted))
for book_id in deleted: for book_id in deleted:
field.table.book_col_map.pop(book_id, None) item_id = table.book_col_map.pop(book_id, None)
field.table.col_book_map.discard(book_id) if item_id is not None:
table.col_book_map[item_id].discard(book_id)
dirtied |= set(deleted) dirtied |= set(deleted)
if updated: if updated:
rid_map = {kmap(v):k for k, v in table.id_map.iteritems()}
book_id_item_id_map = {k:rid_map.get(kmap(v), None) for k, v in
book_id_val_map.iteritems()}
# items that dont yet exist
new_items = {k:v for k, v in updated.iteritems() if new_items = {k:v for k, v in updated.iteritems() if
book_id_item_id_map[k] is None} book_id_item_id_map[k] is None}
# items that already exist
changed_items = {k:book_id_item_id_map[k] for k in updated if changed_items = {k:book_id_item_id_map[k] for k in updated if
book_id_item_id_map[k] is not None} book_id_item_id_map[k] is not None}
def sql_update(imap): def sql_update(imap):
db.conn.executemany( db.conn.executemany(
'DELETE FROM {0} WHERE book=?; INSERT INTO {0}(book,{1}) VALUES(?, ?)' 'DELETE FROM {0} WHERE book=?; INSERT INTO {0}(book,{1}) VALUES(?, ?)'
.format(m['link_table'], m['link_column']), .format(link_table, m['link_column']),
tuple((book_id, book_id, item_id) for book_id, item_id in tuple((book_id, book_id, item_id) for book_id, item_id in
imap.iteritems())) imap.iteritems()))
if new_items: if new_items:
item_ids = {}
val_map = {}
for val in set(new_items.itervalues()):
lval = kmap(val)
if lval in val_map:
item_id = val_map[lval]
else:
db.conn.execute('INSERT INTO %s(%s) VALUES (?)'%(
m['table'], m['column']), (val,))
item_id = val_map[lval] = db.conn.last_insert_rowid()
item_ids[val] = item_id
table.id_map[item_id] = val
imap = {} imap = {}
for book_id, val in new_items.iteritems(): for book_id, val in new_items.iteritems():
db.conn.execute('INSERT INTO %s(%s) VALUES (?)'%( item_id = item_ids[val]
m['table'], m['column']), (val,)) old_item_id = table.book_col_map.get(book_id, None)
imap[book_id] = item_id = db.conn.last_insert_rowid() if old_item_id is not None:
field.table.id_map[item_id] = val table.col_book_map[old_item_id].discard(book_id)
field.table.col_book_map[item_id] = {book_id} if item_id not in table.col_book_map:
field.table.book_col_map[book_id] = item_id table.col_book_map[item_id] = set()
table.col_book_map[item_id].add(book_id)
table.book_col_map[book_id] = imap[book_id] = item_id
sql_update(imap) sql_update(imap)
dirtied |= set(imap) dirtied |= set(imap)
@ -213,24 +251,25 @@ def many_one(book_id_val_map, db, field, allow_case_change, *args):
imap = {} imap = {}
sql_update(changed_items) sql_update(changed_items)
for book_id, item_id in changed_items.iteritems(): for book_id, item_id in changed_items.iteritems():
old_item_id = field.table.book_col_map[book_id] old_item_id = table.book_col_map.get(book_id, None)
if old_item_id != item_id: if old_item_id != item_id:
field.table.book_col_map[book_id] = item_id table.book_col_map[book_id] = item_id
field.table.col_book_map[item_id].add(book_id) table.col_book_map[item_id].add(book_id)
field.table.col_book_map[old_item_id].discard(book_id) if old_item_id is not None:
table.col_book_map[old_item_id].discard(book_id)
imap[book_id] = item_id imap[book_id] = item_id
sql_update(imap) sql_update(imap)
dirtied |= set(imap) dirtied |= set(imap)
# Remove no longer used items # Remove no longer used items
remove = {item_id for item_id, book_ids in remove = {item_id for item_id in table.id_map if not
field.table.col_book_map.iteritems() if not book_ids} table.col_book_map.get(item_id, False)}
if remove: if remove:
db.conn.executemany('DELETE FROM %s WHERE id=?'%m['table'], db.conn.executemany('DELETE FROM %s WHERE id=?'%m['table'],
tuple((item_id,) for item_id in remove)) tuple((item_id,) for item_id in remove))
for item_id in remove: for item_id in remove:
del field.table.id_map[item_id] del table.id_map[item_id]
del field.table.col_book_map[item_id] table.col_book_map.pop(item_id, None)
return dirtied return dirtied
# }}} # }}}