API for per-library restrictions

This commit is contained in:
Kovid Goyal 2017-05-18 14:37:00 +05:30
parent 22d77263f8
commit e457f12194
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
4 changed files with 52 additions and 13 deletions

View File

@ -988,12 +988,17 @@ class Cache(object):
return self._search_api(self, query, restriction, virtual_fields=virtual_fields, book_ids=book_ids)
@read_api
def books_in_virtual_library(self, vl):
def books_in_virtual_library(self, vl, search_restriction=None):
' Return the set of books in the specified virtual library '
vl = self._pref('virtual_libraries', {}).get(vl) if vl else None
if vl is None:
if not vl and not search_restriction:
return self.all_book_ids()
return frozenset(self._search('', vl))
# We utilize the search restriction cache to speed this up
if vl:
if search_restriction:
return frozenset(self._search('', vl) & self._search('', search_restriction))
return frozenset(self._search('', vl))
return frozenset(self._search('', search_restriction))
@api
def get_categories(self, sort='name', book_ids=None, already_fixed=None,

View File

@ -10,7 +10,7 @@ from threading import Lock
from calibre.srv.auth import AuthController
from calibre.srv.errors import HTTPForbidden
from calibre.srv.library_broker import LibraryBroker
from calibre.srv.library_broker import LibraryBroker, path_for_db
from calibre.srv.routes import Router
from calibre.srv.users import UserManager
from calibre.utils.date import utcnow
@ -76,6 +76,21 @@ class Context(object):
raise HTTPForbidden('The user {} is not allowed to access any libraries on this server'.format(request_data.username))
return dict(allowed_libraries), next(allowed_libraries.iterkeys())
def restriction_for(self, request_data, db):
return self.user_manager.library_restriction(request_data.username, path_for_db(db))
def has_id(self, request_data, db, book_id):
restriction = self.restriction_for(request_data, db)
if restriction:
return book_id in db.search('', restriction=restriction)
return db.has_id(book_id)
def allowed_book_ids(self, request_data, db):
restriction = self.restriction_for(request_data, db)
if restriction:
return frozenset(db.search('', restriction=restriction))
return db.all_book_ids()
def check_for_write_access(self, request_data):
if not request_data.username:
if request_data.is_local_connection and self.opts.local_write:
@ -84,9 +99,12 @@ class Context(object):
if self.user_manager.is_readonly(request_data.username):
raise HTTPForbidden('The user {} does not have permission to make changes'.format(request_data.username))
def get_effective_book_ids(self, db, request_data, vl):
return db.books_in_virtual_library(vl, self.restriction_for(request_data, db))
def get_categories(self, request_data, db, sort='name', first_letter_sort=True, vl=''):
restrict_to_ids = db.books_in_virtual_library(vl)
key = (restrict_to_ids, sort, first_letter_sort)
restrict_to_ids = self.get_effective_book_ids(db, request_data, vl)
key = restrict_to_ids, sort, first_letter_sort
with self.lock:
cache = self.library_broker.category_caches[db.server_library_id]
old = cache.pop(key, None)
@ -100,8 +118,8 @@ class Context(object):
return old[1]
def get_tag_browser(self, request_data, db, opts, render, vl=''):
restrict_to_ids = db.books_in_virtual_library(vl)
key = (restrict_to_ids, opts)
restrict_to_ids = self.get_effective_book_ids(db, request_data, vl)
key = restrict_to_ids, opts
with self.lock:
cache = self.library_broker.category_caches[db.server_library_id]
old = cache.pop(key, None)
@ -118,13 +136,13 @@ class Context(object):
return old[1]
def search(self, request_data, db, query, vl=''):
restrict_to_ids = self.get_effective_book_ids(db, request_data, vl)
key = query, restrict_to_ids
with self.lock:
cache = self.library_broker.search_caches[db.server_library_id]
vl = db.pref('virtual_libraries', {}).get(vl) or ''
key = query, vl
old = cache.pop(key, None)
if old is None or old[0] < db.clear_search_cache_count:
matches = db.search(query, restriction=vl)
matches = db.search(query, book_ids=restrict_to_ids)
cache[key] = old = (db.clear_search_cache_count, matches)
if len(cache) > self.SEARCH_CACHE_SIZE:
cache.popitem(last=False)

View File

@ -138,6 +138,10 @@ def load_gui_libraries(gprefs=None):
return sorted(stats, key=stats.get, reverse=True)
def path_for_db(db):
return db.new_api.backend.library_path
class GuiLibraryBroker(LibraryBroker):
def __init__(self, db):
@ -188,7 +192,7 @@ class GuiLibraryBroker(LibraryBroker):
def gui_library_changed(self, db, olddb=None):
# Must be called with lock held
original_path = db.backend.library_path
original_path = path_for_db(db)
newloc = canonicalize_path(original_path)
for library_id, path in self.lmap.iteritems():
if samefile(newloc, path):
@ -201,7 +205,7 @@ class GuiLibraryBroker(LibraryBroker):
self.library_name_map[library_id] = os.path.basename(original_path)
self.loaded_dbs[library_id] = db
db.new_api.server_library_id = library_id
if olddb is not None and samefile(olddb.backend.library_path, db.backend.library_path):
if olddb is not None and samefile(path_for_db(olddb), path_for_db(db)):
# This happens after a restore database, for example
olddb.close(), olddb.break_cycles()
self._prune_loaded_dbs()

View File

@ -28,8 +28,12 @@ def parse_restriction(raw):
r = load_json(raw)
if not isinstance(r, dict):
r = {}
lr = r.get('library_restrictions', {})
if not isinstance(lr, dict):
lr = {}
r['allowed_library_names'] = frozenset(map(lambda x: x.lower(), r.get('allowed_library_names', ())))
r['blocked_library_names'] = frozenset(map(lambda x: x.lower(), r.get('blocked_library_names', ())))
r['library_restrictions'] = {k.lower(): v or '' for k, v in lr.iteritems()}
return r
@ -39,6 +43,7 @@ def serialize_restriction(r):
v = r.get(x)
if v:
ans[x] = list(v)
ans['library_restrictions'] = {l.lower(): v or '' for l, v in r.get('library_restrictions', {}).iteritems()}
return json.dumps(ans)
@ -242,3 +247,10 @@ class UserManager(object):
self._restrictions.pop(username, None)
self.conn.cursor().execute(
'UPDATE users SET restriction=? WHERE name=?', (serialize_restriction(restrictions), username))
def library_restriction(self, username, library_path):
r = self.restrictions(username)
if r is None:
return ''
library_name = os.path.basename(library_path).lower()
return r['library_restrictions'].get(library_name) or ''