using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Kyoo.CommonApi;
using Kyoo.Models;
using Kyoo.Models.Attributes;
using Kyoo.Models.Exceptions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Metadata;

namespace Kyoo.Controllers
{
	/// <summary>
	/// Base class for repositories backed by an Entity Framework Core <see cref="DbContext"/>.
	/// Provides lookup, filtered/sorted/paginated listing, creation, editing and deletion helpers
	/// for resources of type <typeparamref name="T"/>.
	/// </summary>
	/// <typeparam name="T">The resource type handled by this repository.</typeparam>
	public abstract class LocalRepository<T> : IRepository<T>
		where T : class, IResource
	{
		/// <summary>The database context every query of this repository runs against.</summary>
		protected readonly DbContext Database;

		/// <summary>The sort key used when the caller does not supply one.</summary>
		protected abstract Expression<Func<T, object>> DefaultSort { get; }

		/// <summary>Create a repository that queries the given context.</summary>
		/// <param name="database">The context to use. It is disposed when this repository is disposed.</param>
		protected LocalRepository(DbContext database)
		{
			Database = database;
		}

		/// <summary>Dispose the underlying database context.</summary>
		public virtual void Dispose()
		{
			Database.Dispose();
			GC.SuppressFinalize(this);
		}

		/// <summary>Asynchronously dispose the underlying database context.</summary>
		public virtual async ValueTask DisposeAsync()
		{
			await Database.DisposeAsync();
			// Keep parity with Dispose(): a finalizer declared by a derived class must not run after disposal.
			GC.SuppressFinalize(this);
		}

		/// <summary>Get the resource with the given ID, or <c>null</c> if there is none.</summary>
		/// <param name="id">The ID of the resource.</param>
		public virtual Task<T> Get(int id)
		{
			return Database.Set<T>().FirstOrDefaultAsync(x => x.ID == id);
		}

		/// <summary>
		/// Get the resource with the given ID (or <c>null</c>) with change tracking enabled,
		/// so modifications to the returned entity are saved by the next <c>SaveChangesAsync</c>.
		/// </summary>
		/// <param name="id">The ID of the resource.</param>
		public virtual Task<T> GetWithTracking(int id)
		{
			return Database.Set<T>().AsTracking().FirstOrDefaultAsync(x => x.ID == id);
		}

		/// <summary>Get the resource with the given slug, or <c>null</c> if there is none.</summary>
		/// <param name="slug">The slug of the resource.</param>
		public virtual Task<T> Get(string slug)
		{
			return Database.Set<T>().FirstOrDefaultAsync(x => x.Slug == slug);
		}

		/// <summary>Get the first resource matching <paramref name="predicate"/>, or <c>null</c>.</summary>
		/// <param name="predicate">The filter to apply.</param>
		public virtual Task<T> Get(Expression<Func<T, bool>> predicate)
		{
			return Database.Set<T>().FirstOrDefaultAsync(predicate);
		}

		/// <summary>Search resources matching a free-text query.</summary>
		/// <param name="query">The text to search for.</param>
		public abstract Task<ICollection<T>> Search(string query);

		/// <summary>List resources, optionally filtered, sorted and paginated.</summary>
		/// <param name="where">An optional filter.</param>
		/// <param name="sort">The sort to apply; <see cref="DefaultSort"/> is used when it has no key.</param>
		/// <param name="limit">Pagination settings (cursor ID and maximum count).</param>
		public virtual Task<ICollection<T>> GetAll(Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(Database.Set<T>(), where, sort, limit);
		}

		/// <summary>
		/// Apply filtering, sorting and pagination to a query of <typeparamref name="T"/>, using this
		/// repository's <see cref="Get(int)"/> to resolve the pagination cursor and <see cref="DefaultSort"/>.
		/// </summary>
		/// <param name="query">The base query.</param>
		/// <param name="where">An optional filter.</param>
		/// <param name="sort">The sort to apply.</param>
		/// <param name="limit">Pagination settings.</param>
		protected Task<ICollection<T>> ApplyFilters(IQueryable<T> query,
			Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(query, Get, DefaultSort, where, sort, limit);
		}

		/// <summary>
		/// Apply filtering, sorting and keyset pagination to <paramref name="query"/> and materialize it.
		/// </summary>
		/// <param name="query">The base query.</param>
		/// <param name="get">Resolves the item whose ID is <see cref="Pagination.AfterID"/> (the cursor).</param>
		/// <param name="defaultSort">The sort key used when <paramref name="sort"/> has none.</param>
		/// <param name="where">An optional filter.</param>
		/// <param name="sort">The sort key and direction.</param>
		/// <param name="limit">
		/// Pagination settings: when <c>AfterID</c> is not 0 only items after that item (in the sort order)
		/// are returned; when <c>Count</c> is positive at most that many items are returned.
		/// </param>
		/// <typeparam name="TValue">The type of the queried items.</typeparam>
		/// <returns>The matching items.</returns>
		/// <exception cref="ArgumentException">The sort key is an enum.</exception>
		/// <exception cref="ItemNotFound">No item exists with the cursor ID <c>AfterID</c>.</exception>
		protected async Task<ICollection<TValue>> ApplyFilters<TValue>(IQueryable<TValue> query,
			Func<int, Task<TValue>> get,
			Expression<Func<TValue, object>> defaultSort,
			Expression<Func<TValue, bool>> where = null,
			Sort<TValue> sort = default,
			Pagination limit = default)
		{
			if (where != null)
				query = query.Where(where);

			Expression<Func<TValue, object>> sortKey = sort.Key ?? defaultSort;
			// The selector returns object, so value-typed keys are wrapped in a Convert node.
			// Unwrap it to get the real key expression and type.
			Expression sortExpression = sortKey.Body.NodeType == ExpressionType.Convert
				? ((UnaryExpression)sortKey.Body).Operand
				: sortKey.Body;

			if (typeof(Enum).IsAssignableFrom(sortExpression.Type))
				throw new ArgumentException("Invalid sort key.");

			query = sort.Descendant ? query.OrderByDescending(sortKey) : query.OrderBy(sortKey);

			if (limit.AfterID != 0)
			{
				TValue after = await get(limit.AfterID);
				// Without this check the compiled key selector below dereferences null.
				if (after == null)
					throw new ItemNotFound($"No resource found with the ID {limit.AfterID}.");

				Expression key = Expression.Constant(sortKey.Compile()(after), sortExpression.Type);
				// Keyset pagination: keep the items that come after the cursor in the requested order.
				// For a descending sort these are the items with a smaller key, not a greater one.
				// NOTE: items whose key equals the cursor's key are skipped (there is no ID tie-breaker).
				Expression comparison = sort.Descendant
					? ApiHelper.StringCompatibleExpression(Expression.LessThan, sortExpression, key)
					: ApiHelper.StringCompatibleExpression(Expression.GreaterThan, sortExpression, key);
				query = query.Where(Expression.Lambda<Func<TValue, bool>>(
					comparison,
					sortKey.Parameters.First()
				));
			}

			if (limit.Count > 0)
				query = query.Take(limit.Count);

			return await query.ToListAsync();
		}

		/// <summary>Count the resources, optionally restricted by a filter.</summary>
		/// <param name="where">An optional filter.</param>
		public virtual Task<int> GetCount(Expression<Func<T, bool>> where = null)
		{
			IQueryable<T> query = Database.Set<T>();
			if (where != null)
				query = query.Where(where);
			return query.CountAsync();
		}

		/// <summary>
		/// Validate a new resource and return it. This base implementation does not add it to the
		/// database; derived repositories are expected to do that.
		/// </summary>
		/// <param name="obj">The resource to create.</param>
		/// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
		public virtual async Task<T> Create([NotNull] T obj)
		{
			if (obj == null)
				throw new ArgumentNullException(nameof(obj));
			await Validate(obj);
			return obj;
		}

		/// <summary>
		/// Return the resource with the same slug as <paramref name="obj"/> if it exists; otherwise create it.
		/// </summary>
		/// <param name="obj">The resource to look up or create.</param>
		/// <param name="silentFail">
		/// When true, any failure other than a duplicate returns <c>default</c> instead of throwing.
		/// </param>
		/// <exception cref="SystemException">
		/// Create reported a duplicate but no resource with that slug can be found.
		/// </exception>
		public virtual async Task<T> CreateIfNotExists(T obj, bool silentFail = false)
		{
			try
			{
				if (obj == null)
					throw new ArgumentNullException(nameof(obj));

				T old = await Get(obj.Slug);
				if (old != null)
					return old;

				return await Create(obj);
			}
			catch (DuplicatedItemException)
			{
				// Create reported a duplicate: return the resource that already holds this slug.
				T old = await Get(obj!.Slug);
				if (old == null)
					throw new SystemException("Unknown database state.");
				return old;
			}
			catch
			{
				if (silentFail)
					return default;
				throw;
			}
		}

		/// <summary>
		/// Apply the values of <paramref name="edited"/> to the stored resource with the same ID and save.
		/// </summary>
		/// <param name="edited">The new values; its ID selects the resource to edit.</param>
		/// <param name="resetOld">
		/// When true, the stored resource is passed to <c>Utility.Nullify</c> before the new values are copied.
		/// </param>
		/// <returns>The tracked, updated resource.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="edited"/> is null.</exception>
		/// <exception cref="ItemNotFound">No resource exists with the ID of <paramref name="edited"/>.</exception>
		public virtual async Task<T> Edit(T edited, bool resetOld)
		{
			if (edited == null)
				throw new ArgumentNullException(nameof(edited));

			// Lazy loading is turned off for the duration of the edit and restored in the finally block.
			bool lazyLoading = Database.ChangeTracker.LazyLoadingEnabled;
			Database.ChangeTracker.LazyLoadingEnabled = false;
			try
			{
				T old = await GetWithTracking(edited.ID);
				if (old == null)
					throw new ItemNotFound($"No resource found with the ID {edited.ID}.");

				foreach (NavigationEntry navigation in Database.Entry(old).Navigations)
				{
					// NOTE(review): the attribute type argument was missing from the source;
					// EditableRelationAttribute (Kyoo.Models.Attributes) is assumed here — confirm.
					if (navigation.Metadata.PropertyInfo.GetCustomAttribute<EditableRelationAttribute>() != null)
					{
						if (resetOld)
						{
							await navigation.LoadAsync();
							continue;
						}
						IClrPropertyGetter getter = navigation.Metadata.GetGetter();

						// The caller did not supply this relation: leave the stored one untouched.
						if (getter.HasDefaultValue(edited))
							continue;
						await navigation.LoadAsync();
						// TODO this may be useless for lists since the API does not return IDs, but
						// TODO LinkEquality does not check slugs (they are lazy loaded and only the ID is available).
						// if (Utility.ResourceEquals(getter.GetClrValue(edited), getter.GetClrValue(old)))
						//     navigation.Metadata.PropertyInfo.SetValue(edited, default);
					}
					else
						// Non-editable relations are cleared on the input so they are not copied onto old.
						navigation.Metadata.PropertyInfo.SetValue(edited, default);
				}

				if (resetOld)
					Utility.Nullify(old);
				Utility.Complete(old, edited);
				await Validate(old);
				await Database.SaveChangesAsync();
				return old;
			}
			finally
			{
				Database.ChangeTracker.LazyLoadingEnabled = lazyLoading;
			}
		}

		/// <summary>
		/// True when <paramref name="value"/> is non-null and not tracked by the context (Detached).
		/// </summary>
		/// <param name="value">The entity to check.</param>
		/// <typeparam name="T2">The entity type.</typeparam>
		protected bool ShouldValidate<T2>(T2 value)
		{
			return value != null && Database.Entry(value).State == EntityState.Detached;
		}

		/// <summary>
		/// Check that <paramref name="resource"/> has a usable slug. A purely numeric slug gets a
		/// trailing <c>'!'</c> appended when the <c>Slug</c> property has a setter.
		/// </summary>
		/// <param name="resource">The resource to validate; its slug may be modified.</param>
		/// <exception cref="ArgumentException">
		/// The slug is null or empty, or it is numeric and cannot be rewritten.
		/// </exception>
		protected virtual Task Validate(T resource)
		{
			if (string.IsNullOrEmpty(resource.Slug))
				throw new ArgumentException("Resource can't have null as a slug.");
			if (!int.TryParse(resource.Slug, out _))
				return Task.CompletedTask;

			try
			{
				MethodInfo setter = typeof(T).GetProperty(nameof(resource.Slug))?.GetSetMethod();
				if (setter != null)
				{
					setter.Invoke(resource, new object[] {resource.Slug + '!'});
					return Task.CompletedTask;
				}
			}
			catch (Exception ex)
			{
				// Keep the reflection failure as the inner exception instead of discarding it.
				throw new ArgumentException("Resources slug can't be number only.", ex);
			}
			throw new ArgumentException("Resources slug can't be number only.");
		}

		/// <summary>Delete the resource with the given ID.</summary>
		/// <param name="id">The ID of the resource.</param>
		public virtual async Task Delete(int id)
		{
			T resource = await Get(id);
			// NOTE(review): when no resource matches, null is forwarded to Delete(T); its handling is up to
			// the derived implementation.
			await Delete(resource);
		}

		/// <summary>Delete the resource with the given slug.</summary>
		/// <param name="slug">The slug of the resource.</param>
		public virtual async Task Delete(string slug)
		{
			T resource = await Get(slug);
			// NOTE(review): when no resource matches, null is forwarded to Delete(T).
			await Delete(resource);
		}

		/// <summary>Delete the given resource.</summary>
		/// <param name="obj">The resource to delete.</param>
		public abstract Task Delete(T obj);

		/// <summary>Delete each given resource, one after another.</summary>
		/// <param name="objs">The resources to delete.</param>
		public virtual async Task DeleteRange(IEnumerable<T> objs)
		{
			foreach (T obj in objs)
				await Delete(obj);
		}

		/// <summary>Delete the resources with the given IDs, one after another.</summary>
		/// <param name="ids">The IDs of the resources to delete.</param>
		public virtual async Task DeleteRange(IEnumerable<int> ids)
		{
			foreach (int id in ids)
				await Delete(id);
		}

		/// <summary>Delete the resources with the given slugs, one after another.</summary>
		/// <param name="slugs">The slugs of the resources to delete.</param>
		public virtual async Task DeleteRange(IEnumerable<string> slugs)
		{
			foreach (string slug in slugs)
				await Delete(slug);
		}
	}
}