using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Kyoo.Abstractions.Controllers;
using Kyoo.Abstractions.Models;
using Kyoo.Abstractions.Models.Attributes;
using Kyoo.Abstractions.Models.Exceptions;
using Kyoo.Core.Api;
using Kyoo.Utils;
using Microsoft.EntityFrameworkCore;

namespace Kyoo.Core.Controllers
{
	/// <summary>
	/// A base class to create repositories using Entity Framework.
	/// </summary>
	/// <typeparam name="T">The type of this repository</typeparam>
	public abstract class LocalRepository<T> : IRepository<T>
		where T : class, IResource
	{
		/// <summary>
		/// The Entity Framework's Database handle.
		/// </summary>
		protected readonly DbContext Database;

		/// <summary>
		/// The default sort order that will be used for this resource's type.
		/// </summary>
		protected abstract Expression<Func<T, object>> DefaultSort { get; }

		/// <summary>
		/// Create a new base <see cref="LocalRepository{T}"/> with the given database handle.
		/// </summary>
		/// <param name="database">A database connection to load resources of type <typeparamref name="T"/></param>
		protected LocalRepository(DbContext database)
		{
			Database = database;
		}

		/// <inheritdoc/>
		public Type RepositoryType => typeof(T);

		/// <summary>
		/// Get a resource from its ID and make the instance track it.
		/// </summary>
		/// <param name="id">The ID of the resource</param>
		/// <exception cref="ItemNotFoundException">If the item is not found</exception>
		/// <returns>The tracked resource with the given ID</returns>
		protected virtual async Task<T> GetWithTracking(int id)
		{
			T ret = await Database.Set<T>().AsTracking().FirstOrDefaultAsync(x => x.ID == id);
			if (ret == null)
				throw new ItemNotFoundException($"No {typeof(T).Name} found with the id {id}");
			return ret;
		}

		/// <inheritdoc/>
		public virtual async Task<T> Get(int id)
		{
			T ret = await GetOrDefault(id);
			if (ret == null)
				throw new ItemNotFoundException($"No {typeof(T).Name} found with the id {id}");
			return ret;
		}

		/// <inheritdoc/>
		public virtual async Task<T> Get(string slug)
		{
			T ret = await GetOrDefault(slug);
			if (ret == null)
				throw new ItemNotFoundException($"No {typeof(T).Name} found with the slug {slug}");
			return ret;
		}

		/// <inheritdoc/>
		public virtual async Task<T> Get(Expression<Func<T, bool>> where)
		{
			T ret = await GetOrDefault(where);
			if (ret == null)
				throw new ItemNotFoundException($"No {typeof(T).Name} found with the given predicate.");
			return ret;
		}

		/// <inheritdoc/>
		public virtual Task<T> GetOrDefault(int id)
		{
			return Database.Set<T>().FirstOrDefaultAsync(x => x.ID == id);
		}

		/// <inheritdoc/>
		public virtual Task<T> GetOrDefault(string slug)
		{
			return Database.Set<T>().FirstOrDefaultAsync(x => x.Slug == slug);
		}

		/// <inheritdoc/>
		public virtual Task<T> GetOrDefault(Expression<Func<T, bool>> where)
		{
			return Database.Set<T>().FirstOrDefaultAsync(where);
		}

		/// <inheritdoc/>
		public abstract Task<ICollection<T>> Search(string query);

		/// <inheritdoc/>
		public virtual Task<ICollection<T>> GetAll(Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(Database.Set<T>(), where, sort, limit);
		}

		/// <summary>
		/// Apply filters to a query to ease sort, pagination and where queries for resources of this repository.
		/// </summary>
		/// <param name="query">The base query to filter.</param>
		/// <param name="where">An expression to filter based on arbitrary conditions</param>
		/// <param name="sort">The sort settings (sort order and sort by)</param>
		/// <param name="limit">Pagination information (where to start and how many to get)</param>
		/// <returns>The filtered query</returns>
		protected Task<ICollection<T>> ApplyFilters(IQueryable<T> query,
			Expression<Func<T, bool>> where = null,
			Sort<T> sort = default,
			Pagination limit = default)
		{
			return ApplyFilters(query, GetOrDefault, DefaultSort, where, sort, limit);
		}

		/// <summary>
		/// Apply filters to a query to ease sort, pagination and where queries for any resources types.
		/// For resources of type <typeparamref name="T"/>, see
		/// <see cref="ApplyFilters(IQueryable{T}, Expression{Func{T, bool}}, Sort{T}, Pagination)"/>.
		/// </summary>
		/// <typeparam name="TValue">The type of the queried items.</typeparam>
		/// <param name="query">The base query to filter.</param>
		/// <param name="get">A function to asynchronously get a resource from the database using its ID.</param>
		/// <param name="defaultSort">The default sort order of this resource's type.</param>
		/// <param name="where">An expression to filter based on arbitrary conditions</param>
		/// <param name="sort">The sort settings (sort order and sort by)</param>
		/// <param name="limit">
		/// Pagination information (where to start and how many to get). When <c>AfterID</c> is set,
		/// only items placed after that item in the requested sort direction are returned.
		/// </param>
		/// <exception cref="ArgumentException">If the sort key is an enum.</exception>
		/// <exception cref="ItemNotFoundException">If no item matches the pagination's <c>AfterID</c>.</exception>
		/// <returns>The filtered query</returns>
		protected async Task<ICollection<TValue>> ApplyFilters<TValue>(IQueryable<TValue> query,
			Func<int, Task<TValue>> get,
			Expression<Func<TValue, object>> defaultSort,
			Expression<Func<TValue, bool>> where = null,
			Sort<TValue> sort = default,
			Pagination limit = default)
		{
			if (where != null)
				query = query.Where(where);

			Expression<Func<TValue, object>> sortKey = sort.Key ?? defaultSort;
			// Value-typed keys are wrapped in a Convert node to fit Func<TValue, object>;
			// unwrap it so comparisons below use the real key type.
			Expression sortExpression = sortKey.Body.NodeType == ExpressionType.Convert
				? ((UnaryExpression)sortKey.Body).Operand
				: sortKey.Body;

			if (typeof(Enum).IsAssignableFrom(sortExpression.Type))
				throw new ArgumentException("Invalid sort key.");

			query = sort.Descendant ? query.OrderByDescending(sortKey) : query.OrderBy(sortKey);

			if (limit.AfterID != 0)
			{
				TValue after = await get(limit.AfterID);
				// Without this check the compiled sort key would dereference null below.
				if (after == null)
					throw new ItemNotFoundException($"No {typeof(TValue).Name} found with the id {limit.AfterID}");

				Expression key = Expression.Constant(sortKey.Compile()(after), sortExpression.Type);
				// Keyset pagination: "after" means a greater key when sorting ascending
				// but a smaller key when sorting descending.
				Expression comparison = sort.Descendant
					? ApiHelper.StringCompatibleExpression(Expression.LessThan, sortExpression, key)
					: ApiHelper.StringCompatibleExpression(Expression.GreaterThan, sortExpression, key);
				query = query.Where(Expression.Lambda<Func<TValue, bool>>(
					comparison,
					sortKey.Parameters.First()
				));
			}

			if (limit.Count > 0)
				query = query.Take(limit.Count);

			return await query.ToListAsync();
		}

		/// <inheritdoc/>
		public virtual Task<int> GetCount(Expression<Func<T, bool>> where = null)
		{
			IQueryable<T> query = Database.Set<T>();
			if (where != null)
				query = query.Where(where);
			return query.CountAsync();
		}

		/// <inheritdoc/>
		public virtual async Task<T> Create(T obj)
		{
			if (obj == null)
				throw new ArgumentNullException(nameof(obj));
			await Validate(obj);
			return obj;
		}

		/// <inheritdoc/>
		public virtual async Task<T> CreateIfNotExists(T obj)
		{
			if (obj == null)
				throw new ArgumentNullException(nameof(obj));

			try
			{
				T old = await GetOrDefault(obj.Slug);
				if (old != null)
					return old;
				return await Create(obj);
			}
			catch (DuplicatedItemException)
			{
				// The slug was taken between the lookup and the insertion
				// (most likely by a concurrent request): return the stored item.
				return await Get(obj.Slug);
			}
		}

		/// <inheritdoc/>
		public virtual async Task<T> Edit(T edited, bool resetOld)
		{
			if (edited == null)
				throw new ArgumentNullException(nameof(edited));

			// Lazy loading is disabled while merging so that reading navigation
			// properties does not trigger extra queries; it is restored in the finally block.
			bool lazyLoading = Database.ChangeTracker.LazyLoadingEnabled;
			Database.ChangeTracker.LazyLoadingEnabled = false;
			try
			{
				T old = await GetWithTracking(edited.ID);

				if (resetOld)
					old = Merger.Nullify(old);
				Merger.Complete(old, edited, x => x.GetCustomAttribute<LoadableRelationAttribute>() == null);
				await EditRelations(old, edited, resetOld);
				await Database.SaveChangesAsync();
				return old;
			}
			finally
			{
				Database.ChangeTracker.LazyLoadingEnabled = lazyLoading;
				Database.ChangeTracker.Clear();
			}
		}

		/// <summary>
		/// An overridable method to edit relation of a resource.
		/// </summary>
		/// <param name="resource">
		/// The non edited resource
		/// </param>
		/// <param name="changed">
		/// The new version of <paramref name="resource"/>.
		/// This item will be saved on the database and replace <paramref name="resource"/>
		/// </param>
		/// <param name="resetOld">
		/// A boolean to indicate if all values of resource should be discarded or not.
		/// </param>
		/// <returns>A task completing once the relations have been handled.</returns>
		protected virtual Task EditRelations(T resource, T changed, bool resetOld)
		{
			return Validate(resource);
		}

		/// <summary>
		/// A method called just before saving a new resource to the database.
		/// It is also called on the default implementation of <see cref="EditRelations"/>.
		/// Numeric-only slugs are suffixed with <c>'!'</c> when the slug property has a setter.
		/// </summary>
		/// <param name="resource">The resource that will be saved</param>
		/// <exception cref="ArgumentException">
		/// You can throw this if the resource is illegal and should not be saved.
		/// </exception>
		/// <returns>A completed task.</returns>
		protected virtual Task Validate(T resource)
		{
			PropertyInfo slugProperty = typeof(T).GetProperty(nameof(resource.Slug))!;
			if (slugProperty.GetCustomAttribute<ComputedAttribute>() != null)
				return Task.CompletedTask;
			if (string.IsNullOrEmpty(resource.Slug))
				throw new ArgumentException("Resource can't have null as a slug.");
			if (!int.TryParse(resource.Slug, out int _))
				return Task.CompletedTask;

			// A number-only slug could not be told apart from an ID, so try to make it non-numeric.
			MethodInfo setter = slugProperty.GetSetMethod();
			if (setter == null)
				throw new ArgumentException("Resources slug can't be number only.");
			try
			{
				setter.Invoke(resource, new object[] { resource.Slug + '!' });
			}
			catch (Exception ex)
			{
				// Keep the original failure as the inner exception instead of discarding it.
				throw new ArgumentException("Resources slug can't be number only.", ex);
			}
			return Task.CompletedTask;
		}

		/// <inheritdoc/>
		public virtual async Task Delete(int id)
		{
			T resource = await Get(id);
			await Delete(resource);
		}

		/// <inheritdoc/>
		public virtual async Task Delete(string slug)
		{
			T resource = await Get(slug);
			await Delete(resource);
		}

		/// <inheritdoc/>
		public abstract Task Delete(T obj);

		/// <inheritdoc/>
		public async Task DeleteAll(Expression<Func<T, bool>> where)
		{
			foreach (T resource in await GetAll(where))
				await Delete(resource);
		}
	}
}