diff --git a/Kyoo.Common/ExpressionRewrite.cs b/Kyoo.Common/ExpressionRewrite.cs index fadc26b2..c7db9636 100644 --- a/Kyoo.Common/ExpressionRewrite.cs +++ b/Kyoo.Common/ExpressionRewrite.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -18,6 +20,14 @@ namespace Kyoo public class ExpressionRewrite : ExpressionVisitor { + private string _inner; + private readonly List<(string inner, ParameterExpression param, ParameterExpression newParam)> _innerRewrites; + + private ExpressionRewrite() + { + _innerRewrites = new List<(string, ParameterExpression, ParameterExpression)>(); + } + public static Expression Rewrite(Expression expression) { return new ExpressionRewrite().Visit(expression); @@ -30,17 +40,81 @@ namespace Kyoo protected override Expression VisitMember(MemberExpression node) { - ExpressionRewriteAttribute attr = node.Member.GetCustomAttribute(); + (string inner, _, ParameterExpression p) = _innerRewrites.FirstOrDefault(x => x.param == node.Expression); + if (inner != null) + { + Expression param = p; + foreach (string accessor in inner.Split('.')) + param = Expression.Property(param, accessor); + node = Expression.Property(param, node.Member.Name); + } + + // Can't use node.Member directly because we want to support attribute override + MemberInfo member = node.Expression.Type.GetProperty(node.Member.Name) ?? node.Member; + ExpressionRewriteAttribute attr = member!.GetCustomAttribute(); if (attr == null) return base.VisitMember(node); - + Expression property = node.Expression; foreach (string child in attr.Link.Split('.')) property = Expression.Property(property, child); - if (property is MemberExpression member) - Visit(member.Expression); + if (property is MemberExpression expr) + Visit(expr.Expression); + _inner = attr.Inner; return property; } + + protected override Expression VisitLambda(Expression node) + { + (_, ParameterExpression oldParam, ParameterExpression param) = _innerRewrites + .FirstOrDefault(x => node.Parameters.Any(y => y == x.param)); + if (param == null) + return base.VisitLambda(node); + + ParameterExpression[] newParams = node.Parameters.Where(x => x != oldParam).Append(param).ToArray(); + return Expression.Lambda(Visit(node.Body)!, newParams); + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + int count = node.Arguments.Count; + if (node.Object != null) + count++; + if (count != 2) + return base.VisitMethodCall(node); + + Expression instance = node.Object ?? node.Arguments.First(); + Expression argument = node.Object != null + ? node.Arguments.First() + : node.Arguments[1]; + + Type oldType = instance.Type; + instance = Visit(instance); + if (instance!.Type == oldType) + return base.VisitMethodCall(node); + + if (_inner != null && argument is LambdaExpression lambda) + { + // TODO this type handler will usually work with IEnumerable & others but won't work with everything. + Type type = oldType.GetGenericArguments().First(); + ParameterExpression oldParam = lambda.Parameters.FirstOrDefault(x => x.Type == type); + if (oldParam != null) + { + Type newType = instance.Type.GetGenericArguments().First(); + ParameterExpression newParam = Expression.Parameter(newType, oldParam.Name); + _innerRewrites.Add((_inner, oldParam, newParam)); + } + } + argument = Visit(argument); + + // TODO this method handler may not work for some methods (ex: method taking a Fun<> method won't have good generic arguments) + MethodInfo method = node.Method.IsGenericMethod + ? node.Method.GetGenericMethodDefinition().MakeGenericMethod(instance.Type.GetGenericArguments()) + : node.Method; + return node.Object != null + ? Expression.Call(instance, method!, argument) + : Expression.Call(null, method!, instance, argument!); + } } } \ No newline at end of file diff --git a/Kyoo.Common/Utility.cs b/Kyoo.Common/Utility.cs index e1fbbad8..2922acec 100644 --- a/Kyoo.Common/Utility.cs +++ b/Kyoo.Common/Utility.cs @@ -309,12 +309,14 @@ namespace Kyoo public static Expression Convert([CanBeNull] this Expression expr) where T : Delegate { - return expr switch + Expression e = expr switch { null => null, LambdaExpression lambda => new ExpressionConverter(lambda).VisitAndConvert(), _ => throw new ArgumentException("Can't convert a non lambda.") }; + + return ExpressionRewrite.Rewrite(e); } private class ExpressionConverter : ExpressionVisitor diff --git a/Kyoo.CommonAPI/ApiHelper.cs b/Kyoo.CommonAPI/ApiHelper.cs index 55b43878..9ae0067a 100644 --- a/Kyoo.CommonAPI/ApiHelper.cs +++ b/Kyoo.CommonAPI/ApiHelper.cs @@ -24,9 +24,14 @@ namespace Kyoo.CommonApi Expression> defaultWhere = null) { if (where == null || where.Count == 0) - return defaultWhere; - - ParameterExpression param = Expression.Parameter(typeof(T)); + { + if (defaultWhere == null) + return null; + Expression body = ExpressionRewrite.Rewrite(defaultWhere.Body); + return Expression.Lambda>(body, defaultWhere.Parameters.First()); + } + + ParameterExpression param = defaultWhere?.Parameters.First() ?? Expression.Parameter(typeof(T)); Expression expression = defaultWhere?.Body; foreach ((string key, string desired) in where) diff --git a/Kyoo/Models/Resources/GenreDE.cs b/Kyoo/Models/Resources/GenreDE.cs index aceafb05..3d5172ab 100644 --- a/Kyoo/Models/Resources/GenreDE.cs +++ b/Kyoo/Models/Resources/GenreDE.cs @@ -9,6 +9,7 @@ namespace Kyoo.Models { [JsonIgnore] [NotMergable] public virtual IEnumerable Links { get; set; } + [ExpressionRewrite(nameof(Links), nameof(GenreLink.Genre))] [JsonIgnore] [NotMergable] public override IEnumerable Shows { get => Links?.Select(x => x.Show); diff --git a/Kyoo/Models/Resources/ShowDE.cs b/Kyoo/Models/Resources/ShowDE.cs index f0b0868a..7d43f25c 100644 --- a/Kyoo/Models/Resources/ShowDE.cs +++ b/Kyoo/Models/Resources/ShowDE.cs @@ -11,6 +11,7 @@ namespace Kyoo.Models public class ShowDE : Show { [JsonIgnore] [NotMergable] public virtual IEnumerable GenreLinks { get; set; } + [ExpressionRewrite(nameof(GenreLinks), nameof(GenreLink.Genre))] public override IEnumerable Genres { get => GenreLinks?.Select(x => x.Genre);