diff --git a/Kyoo.CommonAPI/LocalRepository.cs b/Kyoo.CommonAPI/LocalRepository.cs index 5417ced1..f35eb3a9 100644 --- a/Kyoo.CommonAPI/LocalRepository.cs +++ b/Kyoo.CommonAPI/LocalRepository.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; using System.Threading.Tasks; +using JetBrains.Annotations; using Kyoo.CommonApi; using Kyoo.Models; using Kyoo.Models.Exceptions; @@ -12,81 +13,52 @@ using Microsoft.EntityFrameworkCore; namespace Kyoo.Controllers { - public abstract class LocalRepository : IRepository + public abstract class LocalRepository where T : class, IResource - where TInternal : class, T { - private readonly DbContext _database; - - protected abstract Expression> DefaultSort { get; } - + protected readonly DbContext Database; + protected abstract Expression> DefaultSort { get; } + + protected LocalRepository(DbContext database) { - _database = database; + Database = database; } public virtual void Dispose() { - _database.Dispose(); + Database.Dispose(); } public virtual ValueTask DisposeAsync() { - return _database.DisposeAsync(); - } - - public Task Get(int id) - { - return _Get(id).Cast(); + return Database.DisposeAsync(); } - public Task Get(string slug) + public virtual Task Get(int id) { - return _Get(slug).Cast(); - } - - protected virtual Task _Get(int id) - { - return _database.Set().FirstOrDefaultAsync(x => x.ID == id); + return Database.Set().FirstOrDefaultAsync(x => x.ID == id); } - protected virtual Task _Get(string slug) + public virtual Task Get(string slug) { - return _database.Set().FirstOrDefaultAsync(x => x.Slug == slug); + return Database.Set().FirstOrDefaultAsync(x => x.Slug == slug); } - - public abstract Task> Search(string query); public virtual Task> GetAll(Expression> where = null, Sort sort = default, Pagination limit = default) { - return ApplyFilters(_database.Set(), where, sort, limit); + return ApplyFilters(Database.Set(), where, sort, limit); } - protected async Task> ApplyFilters(IQueryable query, + protected Task> ApplyFilters(IQueryable query, Expression> where = null, Sort sort = default, Pagination limit = default) { - ICollection items = await ApplyFilters(query, - _Get, - DefaultSort, - where.Convert>(), - sort.To(), - limit); - - return items.ToList(); - } - - protected async Task> ApplyFilters(IQueryable query, - Expression> where = null, - Sort sort = default, - Pagination limit = default) - { - ICollection items = await ApplyFilters(query, _Get, DefaultSort, where, sort, limit); - return items.ToList(); + return ApplyFilters(query, Get, DefaultSort, where, sort, limit); } protected async Task> ApplyFilters(IQueryable query, @@ -123,7 +95,7 @@ namespace Kyoo.Controllers return await query.ToListAsync(); } - + public abstract Task Create(T obj); public virtual async Task CreateIfNotExists(T obj) @@ -146,13 +118,13 @@ namespace Kyoo.Controllers return old; } } - + public virtual async Task Edit(T edited, bool resetOld) { if (edited == null) throw new ArgumentNullException(nameof(edited)); - TInternal old = (TInternal)await Get(edited.Slug); + T old = await Get(edited.Slug); if (old == null) throw new ItemNotFound($"No ressource found with the slug {edited.Slug}."); @@ -161,25 +133,25 @@ namespace Kyoo.Controllers Utility.Nullify(old); Utility.Merge(old, edited); await Validate(old); - await _database.SaveChangesAsync(); + await Database.SaveChangesAsync(); return old; } - - protected virtual Task Validate(TInternal ressource) + + protected virtual Task Validate(T ressource) { - foreach (PropertyInfo property in typeof(TInternal).GetProperties() + foreach (PropertyInfo property in typeof(T).GetProperties() .Where(x => typeof(IEnumerable).IsAssignableFrom(x.PropertyType) && !typeof(string).IsAssignableFrom(x.PropertyType))) { object value = property.GetValue(ressource); if (value is ICollection || value == null) continue; - value = Utility.RunGenericMethod(typeof(Enumerable), "ToList", Utility.GetEnumerableType((IEnumerable)value), new [] { value}); + value = Utility.RunGenericMethod(typeof(Enumerable), "ToList", Utility.GetEnumerableType((IEnumerable)value), value); property.SetValue(ressource, value); } return Task.CompletedTask; } - + public virtual async Task Delete(int id) { T ressource = await Get(id); @@ -212,4 +184,83 @@ namespace Kyoo.Controllers await Delete(slug); } } + + public abstract class LocalRepository : LocalRepository, IRepository + where T : class, IResource + where TInternal : class, T, new() + { + protected LocalRepository(DbContext database) : base(database) { } + + public new Task Get(int id) + { + return base.Get(id).Cast(); + } + + public new Task Get(string slug) + { + return base.Get(slug).Cast(); + } + + public abstract Task> Search(string query); + + public virtual Task> GetAll(Expression> where = null, + Sort sort = default, + Pagination limit = default) + { + return ApplyFilters(Database.Set(), where, sort, limit); + } + + protected virtual async Task> ApplyFilters(IQueryable query, + Expression> where = null, + Sort sort = default, + Pagination limit = default) + { + ICollection items = await ApplyFilters(query, + base.Get, + DefaultSort, + where.Convert>(), + sort.To(), + limit); + + return items.ToList(); + } + + public abstract override Task Create(TInternal obj); + + Task IRepository.Create(T item) + { + TInternal obj = new TInternal(); + Utility.Assign(obj, item); + return Create(obj).Cast(); + } + + Task IRepository.CreateIfNotExists(T item) + { + TInternal obj = new TInternal(); + Utility.Assign(obj, item); + return CreateIfNotExists(obj).Cast(); + } + + public Task Edit(T edited, bool resetOld) + { + TInternal obj = new TInternal(); + Utility.Assign(obj, edited); + return base.Edit(obj, resetOld).Cast(); + } + + public abstract override Task Delete([NotNull] TInternal obj); + + Task IRepository.Delete(T obj) + { + TInternal item = new TInternal(); + Utility.Assign(item, obj); + return Delete(item); + } + + public virtual async Task DeleteRange(IEnumerable objs) + { + foreach (T obj in objs) + await ((IRepository)this).Delete(obj); + } + } } \ No newline at end of file diff --git a/Kyoo/Controllers/Repositories/CollectionRepository.cs b/Kyoo/Controllers/Repositories/CollectionRepository.cs index 844b848d..8d8b6175 100644 --- a/Kyoo/Controllers/Repositories/CollectionRepository.cs +++ b/Kyoo/Controllers/Repositories/CollectionRepository.cs @@ -50,7 +50,7 @@ namespace Kyoo.Controllers .ToListAsync(); } - public override async Task Create(Collection obj) + public override async Task Create(CollectionDE obj) { if (obj == null) throw new ArgumentNullException(nameof(obj)); @@ -60,11 +60,10 @@ namespace Kyoo.Controllers return obj; } - public override async Task Delete(Collection item) + public override async Task Delete(CollectionDE obj) { - if (item == null) - throw new ArgumentNullException(nameof(item)); - CollectionDE obj = new CollectionDE(item); + if (obj == null) + throw new ArgumentNullException(nameof(obj)); _database.Entry(obj).State = EntityState.Deleted; if (obj.Links != null) diff --git a/Kyoo/Controllers/Repositories/EpisodeRepository.cs b/Kyoo/Controllers/Repositories/EpisodeRepository.cs index 9a1bebde..8c9e5be8 100644 --- a/Kyoo/Controllers/Repositories/EpisodeRepository.cs +++ b/Kyoo/Controllers/Repositories/EpisodeRepository.cs @@ -10,7 +10,7 @@ using Microsoft.EntityFrameworkCore; namespace Kyoo.Controllers { - public class EpisodeRepository : LocalRepository, IEpisodeRepository + public class EpisodeRepository : LocalRepository, IEpisodeRepository { private readonly DatabaseContext _database; private readonly IProviderRepository _providers; @@ -87,7 +87,7 @@ namespace Kyoo.Controllers && x.AbsoluteNumber == absoluteNumber); } - public override async Task> Search(string query) + public async Task> Search(string query) { return await _database.Episodes .Where(x => EF.Functions.ILike(x.Title, $"%{query}%")) @@ -133,7 +133,7 @@ namespace Kyoo.Controllers Sort sort = default, Pagination limit = default) { - ICollection episodes = await ApplyFilters(_database.Episodes.Where(x => x.ShowID == showID), + ICollection episodes = await ApplyFilters(_database.Episodes.Where(x => x.ShowID == showID), where, sort, limit);