Adding inner rewrite for expression rewriter

This commit is contained in:
Zoe Roux 2020-09-13 15:22:51 +02:00
parent 8d10b44b1d
commit c42805e415
5 changed files with 91 additions and 8 deletions

View File

@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
@ -18,6 +20,14 @@ namespace Kyoo
public class ExpressionRewrite : ExpressionVisitor
{
private string _inner;
private readonly List<(string inner, ParameterExpression param, ParameterExpression newParam)> _innerRewrites;
private ExpressionRewrite()
{
_innerRewrites = new List<(string, ParameterExpression, ParameterExpression)>();
}
public static Expression Rewrite(Expression expression)
{
return new ExpressionRewrite().Visit(expression);
@ -30,17 +40,81 @@ namespace Kyoo
protected override Expression VisitMember(MemberExpression node)
{
ExpressionRewriteAttribute attr = node.Member.GetCustomAttribute<ExpressionRewriteAttribute>();
(string inner, _, ParameterExpression p) = _innerRewrites.FirstOrDefault(x => x.param == node.Expression);
if (inner != null)
{
Expression param = p;
foreach (string accessor in inner.Split('.'))
param = Expression.Property(param, accessor);
node = Expression.Property(param, node.Member.Name);
}
// Can't use node.Member directly because we want to support attribute override
MemberInfo member = node.Expression.Type.GetProperty(node.Member.Name) ?? node.Member;
ExpressionRewriteAttribute attr = member!.GetCustomAttribute<ExpressionRewriteAttribute>();
if (attr == null)
return base.VisitMember(node);
Expression property = node.Expression;
foreach (string child in attr.Link.Split('.'))
property = Expression.Property(property, child);
if (property is MemberExpression member)
Visit(member.Expression);
if (property is MemberExpression expr)
Visit(expr.Expression);
_inner = attr.Inner;
return property;
}
protected override Expression VisitLambda<T>(Expression<T> node)
{
(_, ParameterExpression oldParam, ParameterExpression param) = _innerRewrites
.FirstOrDefault(x => node.Parameters.Any(y => y == x.param));
if (param == null)
return base.VisitLambda(node);
ParameterExpression[] newParams = node.Parameters.Where(x => x != oldParam).Append(param).ToArray();
return Expression.Lambda(Visit(node.Body)!, newParams);
}
protected override Expression VisitMethodCall(MethodCallExpression node)
{
int count = node.Arguments.Count;
if (node.Object != null)
count++;
if (count != 2)
return base.VisitMethodCall(node);
Expression instance = node.Object ?? node.Arguments.First();
Expression argument = node.Object != null
? node.Arguments.First()
: node.Arguments[1];
Type oldType = instance.Type;
instance = Visit(instance);
if (instance!.Type == oldType)
return base.VisitMethodCall(node);
if (_inner != null && argument is LambdaExpression lambda)
{
// TODO this type handler will usually work with IEnumerable & others but won't work with everything.
Type type = oldType.GetGenericArguments().First();
ParameterExpression oldParam = lambda.Parameters.FirstOrDefault(x => x.Type == type);
if (oldParam != null)
{
Type newType = instance.Type.GetGenericArguments().First();
ParameterExpression newParam = Expression.Parameter(newType, oldParam.Name);
_innerRewrites.Add((_inner, oldParam, newParam));
}
}
argument = Visit(argument);
// TODO this method handler may not work for some methods (ex: method taking a Fun<> method won't have good generic arguments)
MethodInfo method = node.Method.IsGenericMethod
? node.Method.GetGenericMethodDefinition().MakeGenericMethod(instance.Type.GetGenericArguments())
: node.Method;
return node.Object != null
? Expression.Call(instance, method!, argument)
: Expression.Call(null, method!, instance, argument!);
}
}
}

View File

@ -309,12 +309,14 @@ namespace Kyoo
public static Expression<T> Convert<T>([CanBeNull] this Expression expr)
where T : Delegate
{
return expr switch
Expression<T> e = expr switch
{
null => null,
LambdaExpression lambda => new ExpressionConverter<T>(lambda).VisitAndConvert(),
_ => throw new ArgumentException("Can't convert a non lambda.")
};
return ExpressionRewrite.Rewrite<T>(e);
}
private class ExpressionConverter<TTo> : ExpressionVisitor

View File

@ -24,9 +24,14 @@ namespace Kyoo.CommonApi
Expression<Func<T, bool>> defaultWhere = null)
{
if (where == null || where.Count == 0)
return defaultWhere;
ParameterExpression param = Expression.Parameter(typeof(T));
{
if (defaultWhere == null)
return null;
Expression body = ExpressionRewrite.Rewrite(defaultWhere.Body);
return Expression.Lambda<Func<T, bool>>(body, defaultWhere.Parameters.First());
}
ParameterExpression param = defaultWhere?.Parameters.First() ?? Expression.Parameter(typeof(T));
Expression expression = defaultWhere?.Body;
foreach ((string key, string desired) in where)

View File

@ -9,6 +9,7 @@ namespace Kyoo.Models
{
[JsonIgnore] [NotMergable] public virtual IEnumerable<GenreLink> Links { get; set; }
[ExpressionRewrite(nameof(Links), nameof(GenreLink.Genre))]
[JsonIgnore] [NotMergable] public override IEnumerable<Show> Shows
{
get => Links?.Select(x => x.Show);

View File

@ -11,6 +11,7 @@ namespace Kyoo.Models
public class ShowDE : Show
{
[JsonIgnore] [NotMergable] public virtual IEnumerable<GenreLink> GenreLinks { get; set; }
[ExpressionRewrite(nameof(GenreLinks), nameof(GenreLink.Genre))]
public override IEnumerable<Genre> Genres
{
get => GenreLinks?.Select(x => x.Genre);