using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Kyoo.CommonApi;
using Kyoo.Models;
using Kyoo.Models.Attributes;
using Kyoo.Models.Exceptions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Kyoo.Controllers
{
	/// <summary>
	/// Base class for repositories backed by an Entity Framework <see cref="DbContext"/>.
	/// Provides lookups by ID, slug or predicate, filtered/sorted/paginated listing,
	/// validation, creation, editing and deletion helpers.
	/// </summary>
	/// <typeparam name="T">The resource type stored in the database.</typeparam>
	public abstract class LocalRepository<T>
		where T : class, IResource
	{
		/// <summary>
		/// The database context used by this repository. It is disposed together with the repository.
		/// </summary>
		protected readonly DbContext Database;

		/// <summary>
		/// The sort key used when a listing request does not specify one.
		/// </summary>
		protected abstract Expression<Func<T, object>> DefaultSort { get; }

		/// <summary>
		/// Create a repository working on the given database context.
		/// </summary>
		/// <param name="database">The context used for every query of this repository.</param>
		protected LocalRepository(DbContext database)
		{
			Database = database;
		}

		/// <summary>
		/// Dispose the underlying database context.
		/// </summary>
		public virtual void Dispose()
		{
			Database.Dispose();
		}

		/// <summary>
		/// Asynchronously dispose the underlying database context.
		/// </summary>
		public virtual ValueTask DisposeAsync()
		{
			return Database.DisposeAsync();
		}

		/// <summary>
		/// Get the resource with the given ID.
		/// </summary>
		/// <param name="id">The ID of the resource.</param>
		/// <returns>The resource, or null if no resource has this ID.</returns>
		public virtual Task<T> Get(int id)
		{
			return Database.Set<T>().FirstOrDefaultAsync(x => x.ID == id);
		}

		/// <summary>
		/// Get the resource with the given slug.
		/// </summary>
		/// <param name="slug">The slug of the resource.</param>
		/// <returns>The resource, or null if no resource has this slug.</returns>
		public virtual Task<T> Get(string slug)
		{
			return Database.Set<T>().FirstOrDefaultAsync(x => x.Slug == slug);
		}

		/// <summary>
		/// Get the first resource matching a predicate.
		/// </summary>
		/// <param name="predicate">The condition the resource must satisfy.</param>
		/// <returns>The first matching resource, or null if none matches.</returns>
		public virtual Task<T> Get(Expression<Func<T, bool>> predicate)
		{
			return Database.Set<T>().FirstOrDefaultAsync(predicate);
		}

		/// <summary>
		/// List resources, optionally filtered, sorted and paginated.
		/// </summary>
		/// <param name="where">An optional filter; null returns every resource.</param>
		/// <param name="sort">The sort to apply; when its key is null, <see cref="DefaultSort"/> is used.</param>
		/// <param name="limit">Pagination settings (cursor ID and maximum count).</param>
		public virtual Task<ICollection<T>> GetAll(Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(Database.Set<T>(), where, sort, limit);
		}

		/// <summary>
		/// Apply filtering, sorting and pagination to a query of <typeparamref name="T"/>,
		/// using <see cref="Get(int)"/> to resolve the pagination cursor and
		/// <see cref="DefaultSort"/> as fallback sort key.
		/// </summary>
		protected Task<ICollection<T>> ApplyFilters(IQueryable<T> query,
			Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(query, Get, DefaultSort, where, sort, limit);
		}

		/// <summary>
		/// Apply filtering, sorting and keyset pagination to a query, then execute it.
		/// </summary>
		/// <param name="query">The base query.</param>
		/// <param name="get">Resolves the resource identified by <see cref="Pagination.AfterID"/>.</param>
		/// <param name="defaultSort">The sort key used when <paramref name="sort"/> has none.</param>
		/// <param name="where">An optional filter.</param>
		/// <param name="sort">The sort key and direction.</param>
		/// <param name="limit">
		/// When <c>AfterID</c> is non-zero, only items placed after that resource in the requested
		/// order are returned; when <c>Count</c> is positive, at most that many items are returned.
		/// </param>
		/// <exception cref="ArgumentException">The sort key is of an enum type.</exception>
		/// <exception cref="ItemNotFound">No resource exists with the <c>AfterID</c> cursor.</exception>
		protected async Task<ICollection<TValue>> ApplyFilters<TValue>(IQueryable<TValue> query,
			Func<int, Task<TValue>> get,
			Expression<Func<TValue, object>> defaultSort,
			Expression<Func<TValue, bool>> where = null,
			Sort<TValue> sort = default,
			Pagination limit = default)
		{
			if (where != null)
				query = query.Where(where);

			Expression<Func<TValue, object>> sortKey = sort.Key ?? defaultSort;
			// Keys returning value types are wrapped in a Convert node (boxing to object);
			// unwrap it to compare on the real property type.
			Expression sortExpression = sortKey.Body.NodeType == ExpressionType.Convert
				? ((UnaryExpression)sortKey.Body).Operand
				: sortKey.Body;

			if (typeof(Enum).IsAssignableFrom(sortExpression.Type))
				throw new ArgumentException("Invalid sort key.");

			query = sort.Descendant ? query.OrderByDescending(sortKey) : query.OrderBy(sortKey);

			if (limit.AfterID != 0)
			{
				TValue after = await get(limit.AfterID);
				// Evaluating the sort key on a missing cursor would fail with a NullReferenceException.
				if (after == null)
					throw new ItemNotFound($"No resource found with the ID {limit.AfterID}.");
				Expression key = Expression.Constant(sortKey.Compile()(after), sortExpression.Type);
				// "After" means further along in the requested order: greater keys when sorting
				// ascending, smaller keys when sorting descending.
				Expression comparison = sort.Descendant
					? ApiHelper.StringCompatibleExpression(Expression.LessThan, sortExpression, key)
					: ApiHelper.StringCompatibleExpression(Expression.GreaterThan, sortExpression, key);
				query = query.Where(Expression.Lambda<Func<TValue, bool>>(
					comparison,
					sortKey.Parameters.First()
				));
			}
			if (limit.Count > 0)
				query = query.Take(limit.Count);

			return await query.ToListAsync();
		}

		/// <summary>
		/// Count the resources matching a filter.
		/// </summary>
		/// <param name="where">An optional filter; null counts every resource.</param>
		public virtual Task<int> GetCount(Expression<Func<T, bool>> where = null)
		{
			IQueryable<T> query = Database.Set<T>();
			if (where != null)
				query = query.Where(where);
			return query.CountAsync();
		}

		/// <summary>
		/// Validate a new resource and return it. This base implementation only validates;
		/// NOTE(review): derived repositories presumably add the item to the database — confirm.
		/// </summary>
		/// <param name="obj">The resource to create.</param>
		/// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
		public virtual async Task<T> Create([NotNull] T obj)
		{
			if (obj == null)
				throw new ArgumentNullException(nameof(obj));
			await Validate(obj);
			return obj;
		}

		/// <summary>
		/// Return the existing resource with the same slug as <paramref name="obj"/>, or create it.
		/// If creation fails with <see cref="DuplicatedItemException"/> (the item was inserted
		/// concurrently), the stored item is fetched again and returned.
		/// </summary>
		/// <param name="obj">The resource to find or create.</param>
		/// <param name="silentFail">When true, any other error returns default instead of throwing.</param>
		/// <exception cref="InvalidOperationException">
		/// A duplicate was reported but no item with this slug could be found afterwards.
		/// </exception>
		public virtual async Task<T> CreateIfNotExists(T obj, bool silentFail = false)
		{
			try
			{
				if (obj == null)
					throw new ArgumentNullException(nameof(obj));

				T old = await Get(obj.Slug);
				if (old != null)
					return old;

				return await Create(obj);
			}
			catch (DuplicatedItemException)
			{
				T old = await Get(obj!.Slug);
				// InvalidOperationException derives from SystemException, so existing handlers still match.
				if (old == null)
					throw new InvalidOperationException("Unknown database state.");
				return old;
			}
			catch
			{
				if (silentFail)
					return default;
				throw;
			}
		}

		/// <summary>
		/// Update the stored resource that has the same ID as <paramref name="edited"/>.
		/// Navigations marked with <see cref="EditableRelationAttribute"/> are loaded and only kept on
		/// <paramref name="edited"/> when they differ from the stored value; every other navigation
		/// is cleared on <paramref name="edited"/> before merging.
		/// </summary>
		/// <param name="edited">The new values of the resource.</param>
		/// <param name="resetOld">
		/// When true, the stored resource is passed to <c>Utility.Nullify</c> before merging
		/// (presumably resetting its fields so only <paramref name="edited"/>'s values remain — confirm).
		/// </param>
		/// <returns>The updated tracked resource, after saving changes.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="edited"/> is null.</exception>
		/// <exception cref="ItemNotFound">No resource exists with the ID of <paramref name="edited"/>.</exception>
		public virtual async Task<T> Edit(T edited, bool resetOld)
		{
			if (edited == null)
				throw new ArgumentNullException(nameof(edited));

			// Lazy loading is turned off for the duration of the edit (navigations are loaded
			// explicitly with LoadAsync below) and the previous setting is restored in finally.
			bool lazyLoading = Database.ChangeTracker.LazyLoadingEnabled;
			Database.ChangeTracker.LazyLoadingEnabled = false;
			try
			{
				T old = await Get(edited.ID);
				if (old == null)
					throw new ItemNotFound($"No resource found with the ID {edited.ID}.");

				foreach (NavigationEntry navigation in Database.Entry(old).Navigations)
				{
					if (navigation.Metadata.PropertyInfo.GetCustomAttribute<EditableRelationAttribute>() != null)
					{
						if (resetOld)
						{
							await navigation.LoadAsync();
							continue;
						}
						IClrPropertyGetter getter = navigation.Metadata.GetGetter();

						if (getter.HasDefaultValue(edited))
							continue;
						await navigation.LoadAsync();
						// TODO this may be useless for lists since the API does not return IDs but the
						// TODO LinkEquality does not check slugs (they are lazy loaded and only the ID is available)
						if (Utility.ResourceEquals(getter.GetClrValue(edited), getter.GetClrValue(old)))
							navigation.Metadata.PropertyInfo.SetValue(edited, default);
					}
					else
						navigation.Metadata.PropertyInfo.SetValue(edited, default);
				}

				if (resetOld)
					Utility.Nullify(old);
				Utility.Complete(old, edited);
				await Validate(old);
				await Database.SaveChangesAsync();
				return old;
			}
			finally
			{
				Database.ChangeTracker.LazyLoadingEnabled = lazyLoading;
			}
		}

		/// <summary>
		/// Check whether a related value is non-null and not tracked by the context yet.
		/// </summary>
		/// <param name="value">The value to check.</param>
		/// <returns>True if <paramref name="value"/> is non-null and in the Detached state.</returns>
		protected bool ShouldValidate<T2>(T2 value)
		{
			return value != null && Database.Entry(value).State == EntityState.Detached;
		}

		/// <summary>
		/// Validate a resource before it is written. The slug must be non-empty. A number-only slug
		/// (presumably ambiguous with an integer ID) gets a '!' appended when the Slug property is
		/// settable, unless the type carries <see cref="ComposedSlugAttribute"/>.
		/// </summary>
		/// <param name="resource">The resource to validate (may be modified).</param>
		/// <exception cref="ArgumentException">The slug is empty, or is number-only and cannot be fixed.</exception>
		protected virtual Task Validate(T resource)
		{
			if (string.IsNullOrEmpty(resource.Slug))
				throw new ArgumentException("Resource can't have null as a slug.");
			if (int.TryParse(resource.Slug, out int _) && typeof(T).GetCustomAttribute<ComposedSlugAttribute>() == null)
			{
				try
				{
					MethodInfo setter = typeof(T).GetProperty(nameof(resource.Slug))!.GetSetMethod();
					if (setter != null)
						setter.Invoke(resource, new object[] {resource.Slug + '!'});
					else
						throw new ArgumentException("Resources slug can't be number only.");
				}
				catch (Exception ex) when (ex is not ArgumentException)
				{
					// Keep the reflection/setter failure as the inner exception instead of discarding it.
					throw new ArgumentException("Resources slug can't be number only.", ex);
				}
			}

			return Task.CompletedTask;
		}

		/// <summary>
		/// Look up a resource by ID and delete it. The looked-up value (null when the ID does not
		/// exist) is passed as is to <see cref="Delete(T)"/>.
		/// </summary>
		public virtual async Task Delete(int id)
		{
			T resource = await Get(id);
			await Delete(resource);
		}

		/// <summary>
		/// Look up a resource by slug and delete it. The looked-up value (null when the slug does not
		/// exist) is passed as is to <see cref="Delete(T)"/>.
		/// </summary>
		public virtual async Task Delete(string slug)
		{
			T resource = await Get(slug);
			await Delete(resource);
		}

		/// <summary>
		/// Delete a resource. Implemented by concrete repositories.
		/// </summary>
		public abstract Task Delete(T obj);

		/// <summary>
		/// Delete each resource, one after the other.
		/// </summary>
		public virtual async Task DeleteRange(IEnumerable<T> objs)
		{
			foreach (T obj in objs)
				await Delete(obj);
		}

		/// <summary>
		/// Delete each resource identified by ID, one after the other.
		/// </summary>
		public virtual async Task DeleteRange(IEnumerable<int> ids)
		{
			foreach (int id in ids)
				await Delete(id);
		}

		/// <summary>
		/// Delete each resource identified by slug, one after the other.
		/// </summary>
		public virtual async Task DeleteRange(IEnumerable<string> slugs)
		{
			foreach (string slug in slugs)
				await Delete(slug);
		}
	}

	/// <summary>
	/// A repository exposing resources as <typeparamref name="T"/> while storing them as
	/// <typeparamref name="TInternal"/> (a subtype of <typeparamref name="T"/> with a parameterless
	/// constructor). Public values are converted with <c>Utility.Assign</c> when they are not
	/// already <typeparamref name="TInternal"/> instances.
	/// </summary>
	/// <typeparam name="T">The public resource type.</typeparam>
	/// <typeparam name="TInternal">The type actually stored in the database.</typeparam>
	public abstract class LocalRepository<T, TInternal> : LocalRepository<TInternal>, IRepository<T>
		where T : class, IResource
		where TInternal : class, T, new()
	{
		/// <summary>
		/// Create a repository working on the given database context.
		/// </summary>
		protected LocalRepository(DbContext database) : base(database) { }

		/// <summary>
		/// Get the resource with the given ID, or null if none exists.
		/// </summary>
		public new Task<T> Get(int id)
		{
			return base.Get(id).Cast<T>();
		}

		/// <summary>
		/// Get the resource with the given slug, or null if none exists.
		/// </summary>
		public new Task<T> Get(string slug)
		{
			return base.Get(slug).Cast<T>();
		}

		/// <summary>
		/// Get the first resource matching a predicate written against the public type.
		/// </summary>
		public Task<T> Get(Expression<Func<T, bool>> predicate)
		{
			return Get(predicate.Convert<Func<TInternal, bool>>()).Cast<T>();
		}

		/// <summary>
		/// Search resources with a free-text query. Implemented by concrete repositories.
		/// </summary>
		public abstract Task<ICollection<T>> Search(string query);

		/// <summary>
		/// List resources, optionally filtered, sorted and paginated, using predicates and sort keys
		/// written against the public type.
		/// </summary>
		public virtual Task<ICollection<T>> GetAll(Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(Database.Set<TInternal>(), where, sort, limit);
		}

		/// <summary>
		/// Convert the filter and sort to <typeparamref name="TInternal"/>, apply them with the base
		/// implementation and return the results typed as <typeparamref name="T"/>.
		/// </summary>
		protected virtual async Task<ICollection<T>> ApplyFilters(IQueryable<TInternal> query,
			Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			ICollection<TInternal> items = await ApplyFilters(query,
				base.Get,
				DefaultSort,
				where.Convert<Func<TInternal, bool>>(),
				sort.To<TInternal>(),
				limit);

			return items.ToList<T>();
		}

		/// <summary>
		/// Count the resources matching a filter written against the public type.
		/// </summary>
		public virtual Task<int> GetCount(Expression<Func<T, bool>> where = null)
		{
			IQueryable<TInternal> query = Database.Set<TInternal>();
			if (where != null)
				query = query.Where(where.Convert<Func<TInternal, bool>>());
			return query.CountAsync();
		}

		/// <summary>
		/// Convert <paramref name="item"/> to <typeparamref name="TInternal"/>, create it and copy
		/// the resulting ID back onto <paramref name="item"/>.
		/// </summary>
		Task<T> IRepository<T>.Create(T item)
		{
			if (item == null)
				throw new ArgumentNullException(nameof(item));
			TInternal obj = item as TInternal ?? new TInternal();
			if (!(item is TInternal))
				Utility.Assign(obj, item);

			return Create(obj).Cast<T>()
				.Then(x => item.ID = x.ID);
		}

		/// <summary>
		/// Convert <paramref name="item"/> to <typeparamref name="TInternal"/>, find or create it and
		/// copy the resulting ID back onto <paramref name="item"/>.
		/// </summary>
		Task<T> IRepository<T>.CreateIfNotExists(T item, bool silentFail)
		{
			if (item == null)
				throw new ArgumentNullException(nameof(item));
			TInternal obj = item as TInternal ?? new TInternal();
			if (!(item is TInternal))
				Utility.Assign(obj, item);

			return CreateIfNotExists(obj, silentFail).Cast<T>()
				.Then(x => item.ID = x.ID);
		}

		/// <summary>
		/// Edit a resource given as the public type, converting it to <typeparamref name="TInternal"/>
		/// first when needed.
		/// </summary>
		/// <exception cref="ArgumentNullException"><paramref name="edited"/> is null.</exception>
		public Task<T> Edit(T edited, bool resetOld)
		{
			if (edited == null)
				throw new ArgumentNullException(nameof(edited));
			// Both paths must call base.Edit explicitly: an unqualified Edit(intern, ...) binds to this
			// very overload (C# prefers applicable methods of the most derived type), which recursed forever.
			if (edited is TInternal intern)
				return base.Edit(intern, resetOld).Cast<T>();
			TInternal obj = new();
			Utility.Assign(obj, edited);
			return base.Edit(obj, resetOld).Cast<T>();
		}

		/// <summary>
		/// Delete a stored resource. Implemented by concrete repositories.
		/// </summary>
		public abstract override Task Delete([NotNull] TInternal obj);

		/// <summary>
		/// Delete a resource given as the public type, converting it to <typeparamref name="TInternal"/>
		/// first when needed.
		/// </summary>
		Task IRepository<T>.Delete(T obj)
		{
			if (obj == null)
				throw new ArgumentNullException(nameof(obj));
			if (obj is TInternal intern)
				return Delete(intern);
			TInternal item = new();
			Utility.Assign(item, obj);
			return Delete(item);
		}

		/// <summary>
		/// Delete each resource given as the public type, one after the other.
		/// </summary>
		public virtual async Task DeleteRange(IEnumerable<T> objs)
		{
			foreach (T obj in objs)
				await ((IRepository<T>)this).Delete(obj);
		}
	}
}