Improvements to template parser:

- performance improvement: build and save an interpretable tree instead of parsing and executing at the same time
- improve syntax error detection
- performance improvement: automatically convert "test" to "if then else", avoiding executing the unused branch
- inline raw_field (field was already inlined)
This commit is contained in:
Charles Haley 2020-09-13 12:19:18 +01:00
parent 28ef780d99
commit 1642ac6af2

View File

@ -17,6 +17,70 @@ from calibre.utils.formatter_functions import formatter_functions
from calibre.utils.icu import strcmp
from polyglot.builtins import unicode_type, error_message
class Node(object):
NODE_RVALUE = 1
NODE_IF = 2
NODE_ASSIGN = 3
NODE_FUNC = 4
NODE_INFIX = 5
NODE_CONSTANT = 6
NODE_FIELD = 7
NODE_RAW_FIELD = 8
class IfNode(Node):
def __init__(self, condition, then_part, else_part):
Node.__init__(self)
self.node_type = self.NODE_IF
self.condition = condition
self.then_part = then_part
self.else_part = else_part
class AssignNode(Node):
def __init__(self, left, right):
Node.__init__(self)
self.node_type = self.NODE_ASSIGN
self.left = left
self.right = right
class FunctionNode(Node):
def __init__(self, function_name, expression_list):
Node.__init__(self)
self.node_type = self.NODE_FUNC
self.name = function_name
self.expression_list = expression_list
class InfixNode(Node):
def __init__(self, operator, left, right):
Node.__init__(self)
self.node_type = self.NODE_INFIX
self.operator = operator
self.left = left
self.right = right
class ConstantNode(Node):
def __init__(self, value):
Node.__init__(self)
self.node_type = self.NODE_CONSTANT
self.value = value
class VariableNode(Node):
def __init__(self, name):
Node.__init__(self)
self.node_type = self.NODE_RVALUE
self.name = name
class FieldNode(Node):
def __init__(self, expression):
Node.__init__(self)
self.node_type = self.NODE_FIELD
self.expression = expression
class RawFieldNode(Node):
def __init__(self, expression):
Node.__init__(self)
self.node_type = self.NODE_RAW_FIELD
self.expression = expression
class _Parser(object):
LEX_OP = 1
@ -24,22 +88,7 @@ class _Parser(object):
LEX_CONST = 3
LEX_EOF = 4
LEX_INFIX = 5
LEX_IF = 6
LEX_THEN = 7
LEX_ELSE = 8
LEX_FI = 9
def __init__(self, val, prog, funcs, parent):
self.lex_pos = 0
self.prog = prog[0]
self.prog_len = len(self.prog)
if prog[1] != '':
self.error(_('Failed to scan program. Invalid input {0}').format(prog[1]))
self.parent = parent
self.parent_kwargs = parent.kwargs
self.parent_book = parent.book
self.locals = {'$':val}
self.funcs = funcs
LEX_KEYWORD = 6
def error(self, message):
try:
@ -115,25 +164,29 @@ class _Parser(object):
def token_is_if(self):
try:
return self.prog[self.lex_pos][0] == self.LEX_IF
token = self.prog[self.lex_pos]
return token[1] == 'if' and token[0] == self.LEX_KEYWORD
except:
return False
def token_is_then(self):
try:
return self.prog[self.lex_pos][0] == self.LEX_THEN
token = self.prog[self.lex_pos]
return token[1] == 'then' and token[0] == self.LEX_KEYWORD
except:
return False
def token_is_else(self):
try:
return self.prog[self.lex_pos][0] == self.LEX_ELSE
token = self.prog[self.lex_pos]
return token[1] == 'else' and token[0] == self.LEX_KEYWORD
except:
return False
def token_is_fi(self):
try:
return self.prog[self.lex_pos][0] == self.LEX_FI
token = self.prog[self.lex_pos]
return token[1] == 'fi' and token[0] == self.LEX_KEYWORD
except:
return False
@ -149,60 +202,118 @@ class _Parser(object):
except:
return True
def program(self):
val = self.statement()
def program(self, funcs, prog):
self.lex_pos = 0
self.funcs = funcs
self.prog = prog[0]
self.prog_len = len(self.prog)
if prog[1] != '':
self.error(_('Failed to scan program. Invalid input {0}').format(prog[1]))
tree = self.expression_list()
if not self.token_is_eof():
self.error(_('Syntax error - program ends before EOF'))
return val
return tree
def statement(self):
val = ''
def expression_list(self):
expr_list = []
while not self.token_is_eof():
val = self.infix_expr()
expr_list.append(self.infix_expr())
if not self.token_op_is_semicolon():
break
self.consume()
return val
def consume_if(self):
self.consume()
while not self.token_is_fi():
if self.token_is_if():
self.consume_if()
self.consume()
def consume_then_branch(self):
while not (self.token_is_eof() or self.token_is_fi() or self.token_is_else()):
if self.token_is_if():
self.consume_if()
self.consume()
def consume_else_branch(self):
while not (self.token_is_eof() or self.token_is_fi()):
if self.token_is_if():
self.consume_if()
self.consume()
return expr_list
def if_expression(self):
self.consume()
val = ''
test_part = self.infix_expr()
condition = self.infix_expr()
if not self.token_is_then():
self.error(_("Missing 'then' in if statement"))
if test_part:
self.consume()
then_part = self.expression_list()
if self.token_is_else():
self.consume()
val = self.statement()
if not (self.token_is_else() or self.token_is_fi()):
self.error(_("Missing 'else' or 'fi' in if statement"))
self.consume_else_branch()
else_part = self.expression_list()
else:
self.consume_then_branch()
if self.token_is_else():
self.consume()
val = self.statement()
else_part = None
if not self.token_is_fi():
self.error(_("Missing 'fi' in if statement"))
self.consume()
return IfNode(condition, then_part, else_part)
def infix_expr(self):
left = self.expr()
if not self.token_op_is_infix_compare():
return left
operator = self.token()
return InfixNode(operator, left, self.expr())
def expr(self):
if self.token_is_if():
return self.if_expression()
if self.token_is_id():
# We have an identifier. Determine if it is a function
id_ = self.token()
if not self.token_op_is_lparen():
if self.token_op_is_equals():
# classic assignment statement
self.consume()
return AssignNode(id_, self.infix_expr())
return VariableNode(id_)
# We have a function.
# Check if it is a known one. We do this here so error reporting is
# better, as it can identify the tokens near the problem.
id_ = id_.strip()
if id_ not in self.funcs:
self.error(_('Unknown function {0}').format(id_))
# Eat the paren
self.consume()
arguments = list()
while not self.token_op_is_rparen():
# evaluate the expression (recursive call)
arguments.append(self.infix_expr())
if not self.token_op_is_comma():
break
self.consume()
if self.token() != ')':
self.error(_('Missing closing parenthesis'))
if id_ == 'field' and len(arguments) == 1:
return FieldNode(arguments[0])
if id_ == 'raw_field' and len(arguments) == 1:
return RawFieldNode(arguments[0])
if id_ == 'test' and len(arguments) == 3:
return IfNode(arguments[0], (arguments[1],), (arguments[2],))
if (id_ == 'assign' and len(arguments) == 2
and arguments[0].node_type == Node.NODE_RVALUE):
return AssignNode(arguments[0].name, arguments[1])
cls = self.funcs[id_]
if cls.arg_count != -1 and len(arguments) != cls.arg_count:
self.error(_('Incorrect number of expression_list for function {0}').format(id_))
return FunctionNode(id_, arguments)
elif self.token_is_constant():
# String or number
return ConstantNode(self.token())
else:
self.error(_('Expression is not function or constant'))
class _Interpreter(object):
def error(self, message):
m = 'Interpreter: ' + message
raise ValueError(m)
def program(self, funcs, parent, prog, val):
self.parent = parent
self.parent_kwargs = parent.kwargs
self.parent_book = parent.book
self.funcs = funcs
self.locals = {'$':val}
return self.expression_list(prog)
def expression_list(self, prog):
val = ''
for p in prog:
val = self.expr(p)
return val
INFIX_OPS = {
@ -220,74 +331,93 @@ class _Parser(object):
">=#": lambda x, y: float(x) >= float(y) if x and y else False,
}
def infix_expr(self):
left = self.expr()
if self.token_op_is_infix_compare():
t = self.token()
right = self.expr()
return '1' if self.INFIX_OPS[t](left, right) else ''
return left
def do_node_infix(self, prog):
left = self.expr(prog.left)
right = self.expr(prog.right)
return '1' if self.INFIX_OPS[prog.operator](left, right) else ''
def expr(self):
if self.token_is_if():
return self.if_expression()
if self.token_is_id():
# We have an identifier. Determine if it is a function
id_ = self.token()
if not self.token_op_is_lparen():
if self.token_op_is_equals():
# classic assignment statement
self.consume()
cls = self.funcs['assign']
return cls.eval_(self.parent, self.parent_kwargs,
self.parent_book, self.locals, id_, self.infix_expr())
val = self.locals.get(id_, None)
if val is None:
self.error(_('Unknown identifier {0}').format(id_))
return val
# We have a function.
# Check if it is a known one. We do this here so error reporting is
# better, as it can identify the tokens near the problem.
id_ = id_.strip()
if id_ not in self.funcs:
self.error(_('Unknown function {0}').format(id_))
def do_node_if(self, prog):
test_part = self.expr(prog.condition)
if test_part:
return self.expression_list(prog.then_part)
elif prog.else_part:
return self.expression_list(prog.else_part)
return ''
# Eat the paren
self.consume()
args = list()
while not self.token_op_is_rparen():
if id_ == 'assign' and len(args) == 0:
# Must handle the lvalue semantics of the assign function.
# The first argument is the name of the destination, not
# the value.
if not self.token_is_id():
self.error(_("'Assign' requires the first parameter be an id"))
args.append(self.token())
else:
# evaluate the argument (recursive call)
args.append(self.infix_expr())
if not self.token_op_is_comma():
break
self.consume()
if self.token() != ')':
self.error(_('Missing closing parenthesis'))
def do_node_rvalue(self, prog):
try:
return self.locals[prog.name]
except:
self.error(_('Unknown identifier {0}').format(prog.name))
# Evaluate the function.
if id_ == 'field':
# Evaluate the 'field' function inline for performance
if len(args) != 1:
self.error(_('Incorrect number of arguments for function {0}').format(id_))
return self.parent.get_value(args[0], [], self.parent_kwargs)
cls = self.funcs[id_]
if cls.arg_count != -1 and len(args) != cls.arg_count:
self.error(_('Incorrect number of arguments for function {0}').format(id_))
return cls.eval_(self.parent, self.parent_kwargs,
self.parent_book, self.locals, *args)
elif self.token_is_constant():
# String or number
return self.token()
else:
self.error(_('Expression is not function or constant'))
def do_node_func(self, prog):
args = list()
for arg in prog.expression_list:
# evaluate the expression (recursive call)
args.append(self.expr(arg))
# Evaluate the function.
id_ = prog.name.strip()
cls = self.funcs[id_]
return cls.eval_(self.parent, self.parent_kwargs,
self.parent_book, self.locals, *args)
def do_node_constant(self, prog):
return prog.value
def do_node_field(self, prog):
try:
name = self.expr(prog.expression)
try:
return self.parent.get_value(name, [], self.parent_kwargs)
except:
self.error(_('Unknown field {0}').format(name))
except ValueError as e:
raise e
except:
self.error(_('Unknown field {0}').format('parse error'))
def do_node_raw_field(self, prog):
try:
name = self.expr(prog.expression)
res = getattr(self.parent_book, name, None)
if res is None:
self.error(_('Unknown field {0}').format(name))
if isinstance(res, list):
fm = self.parent_book.metadata_for_field(name)
if fm is None:
return ', '.join(res)
return fm['is_multiple']['list_to_ui'].join(res)
return unicode_type(res)
except ValueError as e:
raise e
except:
self.error(_('Unknown field {0}').format('parse error'))
def do_node_assign(self, prog):
t = self.expr(prog.right)
self.locals[prog.left] = t
return t
NODE_OPS = {
Node.NODE_IF: do_node_if,
Node.NODE_ASSIGN: do_node_assign,
Node.NODE_CONSTANT: do_node_constant,
Node.NODE_RVALUE: do_node_rvalue,
Node.NODE_FUNC: do_node_func,
Node.NODE_FIELD: do_node_field,
Node.NODE_RAW_FIELD:do_node_raw_field,
Node.NODE_INFIX: do_node_infix,
}
def expr(self, prog):
try:
return self.NODE_OPS[prog.node_type](self, prog)
except ValueError as e:
raise e
except:
if (DEBUG):
traceback.print_exc()
self.error(_('Internal error evaluating an expression'))
class TemplateFormatter(string.Formatter):
@ -308,6 +438,8 @@ class TemplateFormatter(string.Formatter):
self.strip_results = True
self.locals = {}
self.funcs = formatter_functions().get_functions()
self.gpm_parser = _Parser()
self.gpm_interpreter = _Interpreter()
def _do_format(self, val, fmt):
if not fmt or not val:
@ -356,17 +488,14 @@ class TemplateFormatter(string.Formatter):
lex_scanner = re.Scanner([
(r'(==#|!=#|<=#|<#|>=#|>#|==|!=|<=|<|>=|>)',
lambda x,t: (_Parser.LEX_INFIX, t)),
(r'if\b', lambda x,t: (_Parser.LEX_IF, t)), # noqa
(r'then\b', lambda x,t: (_Parser.LEX_THEN, t)), # noqa
(r'else\b', lambda x,t: (_Parser.LEX_ELSE, t)), # noqa
(r'fi\b', lambda x,t: (_Parser.LEX_FI, t)), # noqa
(r'[(),=;]', lambda x,t: (_Parser.LEX_OP, t)), # noqa
(r'-?[\d\.]+', lambda x,t: (_Parser.LEX_CONST, t)), # noqa
(r'\$', lambda x,t: (_Parser.LEX_ID, t)), # noqa
(r'\w+', lambda x,t: (_Parser.LEX_ID, t)), # noqa
(r'".*?((?<!\\)")', lambda x,t: (_Parser.LEX_CONST, t[1:-1])), # noqa
(r'\'.*?((?<!\\)\')', lambda x,t: (_Parser.LEX_CONST, t[1:-1])), # noqa
lambda x,t: (_Parser.LEX_INFIX, t)),
(r'(if|then|else|fi)\b', lambda x,t: (_Parser.LEX_KEYWORD, t)), # noqa
(r'[(),=;]', lambda x,t: (_Parser.LEX_OP, t)), # noqa
(r'-?[\d\.]+', lambda x,t: (_Parser.LEX_CONST, t)), # noqa
(r'\$', lambda x,t: (_Parser.LEX_ID, t)), # noqa
(r'\w+', lambda x,t: (_Parser.LEX_ID, t)), # noqa
(r'".*?((?<!\\)")', lambda x,t: (_Parser.LEX_CONST, t[1:-1])), # noqa
(r'\'.*?((?<!\\)\')', lambda x,t: (_Parser.LEX_CONST, t[1:-1])), # noqa
(r'\n#.*?(?:(?=\n)|$)', None),
(r'\s', None),
], flags=re.DOTALL)
@ -376,14 +505,13 @@ class TemplateFormatter(string.Formatter):
# is much more expensive than the cache lookup. This is certainly true
# for more than a few tokens, but it isn't clear for simple programs.
if column_name is not None and self.template_cache is not None:
lprog = self.template_cache.get(column_name, None)
if not lprog:
lprog = self.lex_scanner.scan(prog)
self.template_cache[column_name] = lprog
tree = self.template_cache.get(column_name, None)
if not tree:
tree = self.gpm_parser.program(self.funcs, self.lex_scanner.scan(prog))
self.template_cache[column_name] = tree
else:
lprog = self.lex_scanner.scan(prog)
parser = _Parser(val, lprog, self.funcs, self)
return parser.program()
tree = self.gpm_parser.program(self.funcs, self.lex_scanner.scan(prog))
return self.gpm_interpreter.program(self.funcs, self, tree, val)
# ################# Override parent classes methods #####################
@ -440,7 +568,7 @@ class TemplateFormatter(string.Formatter):
args = [self.backslash_comma_to_comma.sub(',', a) for a in args]
if (func.arg_count == 1 and (len(args) != 1 or args[0])) or \
(func.arg_count > 1 and func.arg_count != len(args)+1):
raise ValueError('Incorrect number of arguments for function '+ fmt[0:p])
raise ValueError('Incorrect number of expression_list for function '+ fmt[0:p])
if func.arg_count == 1:
val = func.eval_(self, self.kwargs, self.book, self.locals, val)
if self.strip_results: