Implement per-user library access restrictions

This commit is contained in:
Kovid Goyal 2017-04-15 12:08:12 +05:30
parent 6549baf27e
commit fa0fa3cba9
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 114 additions and 13 deletions

View File

@ -50,3 +50,9 @@ class HTTPBadRequest(HTTPSimpleResponse):
def __init__(self, message, close_connection=False): def __init__(self, message, close_connection=False):
HTTPSimpleResponse.__init__(self, httplib.BAD_REQUEST, message, close_connection) HTTPSimpleResponse.__init__(self, httplib.BAD_REQUEST, message, close_connection)
class HTTPForbidden(HTTPSimpleResponse):
def __init__(self, http_message='', close_connection=True):
HTTPSimpleResponse.__init__(self, httplib.FORBIDDEN, http_message, close_connection)

View File

@ -1,19 +1,18 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
# vim:fileencoding=utf-8 # vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import, # License: GPLv3 Copyright: 2015, Kovid Goyal <kovid at kovidgoyal.net>
print_function) from __future__ import absolute_import, division, print_function, unicode_literals
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import json import json
from functools import partial
from importlib import import_module from importlib import import_module
from threading import Lock from threading import Lock
from calibre.srv.auth import AuthController from calibre.srv.auth import AuthController
from calibre.srv.errors import HTTPForbidden
from calibre.srv.library_broker import LibraryBroker
from calibre.srv.routes import Router from calibre.srv.routes import Router
from calibre.srv.users import UserManager from calibre.srv.users import UserManager
from calibre.srv.library_broker import LibraryBroker
from calibre.utils.date import utcnow from calibre.utils.date import utcnow
@ -52,12 +51,25 @@ class Context(object):
pass pass
def get_library(self, data, library_id=None): def get_library(self, data, library_id=None):
# TODO: Restrict the libraries based on data.username if not data.username:
return self.library_broker.get(library_id) return self.library_broker.get(library_id)
lf = partial(self.user_manager.allowed_library_names, data.username)
allowed_libraries = self.library_broker.allowed_libraries(lf)
if not allowed_libraries:
raise HTTPForbidden('The user {} is not allowed to access any libraries on this server'.format(data.username))
library_id = library_id or next(allowed_libraries.iterkeys())
if library_id in allowed_libraries:
return self.library_broker.get(library_id)
raise HTTPForbidden('The user {} is not allowed to access the library {}'.format(data.username, library_id))
def library_info(self, data): def library_info(self, data):
# TODO: Restrict the libraries based on data.username if not data.username:
return self.library_broker.library_map, self.library_broker.default_library return self.library_broker.library_map, self.library_broker.default_library
lf = partial(self.user_manager.allowed_library_names, data.username)
allowed_libraries = self.library_broker.allowed_libraries(lf)
if not allowed_libraries:
raise HTTPForbidden('The user {} is not allowed to access any libraries on this server'.format(data.username))
return dict(allowed_libraries), next(allowed_libraries.iterkeys())
def allowed_book_ids(self, data, db): def allowed_book_ids(self, data, db):
with self.lock: with self.lock:

View File

@ -104,7 +104,13 @@ class LibraryBroker(object):
@property @property
def library_map(self): def library_map(self):
return {k: os.path.basename(v) for k, v in self.lmap.iteritems()} with self:
return {k: os.path.basename(v) for k, v in self.lmap.iteritems()}
def allowed_libraries(self, filter_func):
with self:
allowed_names = filter_func(os.path.basename(l) for l in self.lmap.itervalues())
return OrderedDict(((lid, path) for lid, path in self.lmap.iteritems() if os.path.basename(path) in allowed_names))
def __enter__(self): def __enter__(self):
self.lock.acquire() self.lock.acquire()

View File

@ -7,11 +7,14 @@ __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import httplib, base64, urllib2, subprocess, os, cookielib import httplib, base64, urllib2, subprocess, os, cookielib
from collections import namedtuple
try: try:
from distutils.spawn import find_executable from distutils.spawn import find_executable
except ImportError: # windows except ImportError: # windows
find_executable = lambda x: None find_executable = lambda x: None
from calibre.ptempfile import TemporaryDirectory
from calibre.srv.errors import HTTPForbidden
from calibre.srv.tests.base import BaseTest, TestServer from calibre.srv.tests.base import BaseTest, TestServer
from calibre.srv.routes import endpoint, Router from calibre.srv.routes import endpoint, Router
@ -116,6 +119,43 @@ class TestAuth(BaseTest):
self.ae((httplib.UNAUTHORIZED, b''), request('asf', 'testpw')) self.ae((httplib.UNAUTHORIZED, b''), request('asf', 'testpw'))
# }}} # }}}
def test_library_restrictions(self): # {{{
from calibre.srv.opts import Options
from calibre.srv.handler import Handler
from calibre.db.legacy import create_backend
opts = Options(userdb=':memory:')
Data = namedtuple('Data', 'username')
with TemporaryDirectory() as base:
l1, l2, l3 = map(lambda x: os.path.join(base, 'l' + x), '123')
for l in (l1, l2, l3):
create_backend(l).close()
ctx = Handler((l1, l2, l3), opts).router.ctx
um = ctx.user_manager
def get_library(username=None, library_id=None):
ans = ctx.get_library(Data(username), library_id=library_id)
return os.path.basename(ans.backend.library_path)
def library_info(username=None):
lmap, defaultlib = ctx.library_info(Data(username))
lmap = {k:os.path.basename(v) for k, v in lmap.iteritems()}
return lmap, defaultlib
self.assertEqual(get_library(), 'l1')
self.assertEqual(library_info()[0], {'l%d'%i:'l%d'%i for i in range(1, 4)})
self.assertEqual(library_info()[1], 'l1')
self.assertRaises(HTTPForbidden, get_library, 'xxx')
um.add_user('a', 'a')
self.assertEqual(library_info('a')[0], {'l%d'%i:'l%d'%i for i in range(1, 4)})
um.update_user_restrictions('a', {'blocked_library_names': ['l2']})
self.assertEqual(library_info('a')[0], {'l%d'%i:'l%d'%i for i in range(1, 4) if i != 2})
um.update_user_restrictions('a', {'allowed_library_names': ['l3']})
self.assertEqual(library_info('a')[0], {'l%d'%i:'l%d'%i for i in range(1, 4) if i == 3})
self.assertEqual(library_info('a')[1], 'l3')
self.assertRaises(HTTPForbidden, get_library, 'a', 'l1')
# }}}
def test_digest_auth(self): # {{{ def test_digest_auth(self): # {{{
'Test HTTP Digest auth' 'Test HTTP Digest auth'
from calibre.srv.http_request import normalize_header_name from calibre.srv.http_request import normalize_header_name

View File

@ -61,6 +61,7 @@ class UserManager(object):
def __init__(self, path=None): def __init__(self, path=None):
self.path = os.path.join(config_dir, 'server-users.sqlite') if path is None else path self.path = os.path.join(config_dir, 'server-users.sqlite') if path is None else path
self._conn = None self._conn = None
self._restrictions = {}
def get_session_data(self, username): def get_session_data(self, username):
with self.lock: with self.lock:
@ -100,14 +101,17 @@ class UserManager(object):
except ValueError: except ValueError:
return _('The password must contain only ASCII (English) characters and symbols') return _('The password must contain only ASCII (English) characters and symbols')
def add_user(self, username, pw, restriction='', readonly=False): def add_user(self, username, pw, restriction=None, readonly=False):
with self.lock: with self.lock:
msg = self.validate_username(username) or self.validate_password(pw) msg = self.validate_username(username) or self.validate_password(pw)
if msg is not None: if msg is not None:
raise ValueError(msg) raise ValueError(msg)
restriction = restriction or {}
if not isinstance(restriction, dict):
raise TypeError('restriction must be a dict')
self.conn.cursor().execute( self.conn.cursor().execute(
'INSERT INTO users (name, pw, restriction, readonly) VALUES (?, ?, ?, ?)', 'INSERT INTO users (name, pw, restriction, readonly) VALUES (?, ?, ?, ?)',
(username, pw, restriction, ('y' if readonly else 'n'))) (username, pw, json.dumps(restriction), ('y' if readonly else 'n')))
def remove_user(self, username): def remove_user(self, username):
with self.lock: with self.lock:
@ -127,3 +131,36 @@ class UserManager(object):
raise ValueError(msg) raise ValueError(msg)
self.conn.cursor().execute( self.conn.cursor().execute(
'UPDATE users SET pw=? WHERE name=?', (pw, username)) 'UPDATE users SET pw=? WHERE name=?', (pw, username))
def restrictions(self, username):
with self.lock:
r = self._restrictions.get(username)
if r is None:
for restriction, in self.conn.cursor().execute(
'SELECT restriction FROM users WHERE name=?', (username,)):
self._restrictions[username] = r = json.loads(restriction)
r['allowed_library_names'] = frozenset(r.get('allowed_library_names', ()))
r['blocked_library_names'] = frozenset(r.get('blocked_library_names', ()))
break
return r
def allowed_library_names(self, username, all_library_names):
' Get allowed library names for specified user from set of all library names '
r = self.restrictions(username)
if r is None:
return set()
inc = r['allowed_library_names']
exc = r['blocked_library_names']
def check(n):
n = n.lower()
return (not inc or n in inc) and n not in exc
return {n for n in all_library_names if check(n)}
def update_user_restrictions(self, username, restrictions):
if not isinstance(restrictions, dict):
raise TypeError('restrictions must be a dict')
with self.lock:
self._restrictions.pop(username, None)
self.conn.cursor().execute(
'UPDATE users SET restriction=? WHERE name=?', (json.dumps(restrictions), username))