mirror of
https://github.com/Kareadita/Kavita.git
synced 2025-05-31 20:24:27 -04:00
Custom Headers Test 4 (#2524)
This commit is contained in:
parent
d85208513f
commit
3765ad929a
@ -72,7 +72,7 @@ public interface IUserRepository
|
|||||||
Task<AppUser?> GetUserByIdAsync(int userId, AppUserIncludes includeFlags = AppUserIncludes.None);
|
Task<AppUser?> GetUserByIdAsync(int userId, AppUserIncludes includeFlags = AppUserIncludes.None);
|
||||||
Task<int> GetUserIdByUsernameAsync(string username);
|
Task<int> GetUserIdByUsernameAsync(string username);
|
||||||
Task<IList<AppUserBookmark>> GetAllBookmarksByIds(IList<int> bookmarkIds);
|
Task<IList<AppUserBookmark>> GetAllBookmarksByIds(IList<int> bookmarkIds);
|
||||||
Task<AppUser?> GetUserByEmailAsync(string email);
|
Task<AppUser?> GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None);
|
||||||
Task<IEnumerable<AppUserPreferences>> GetAllPreferencesByThemeAsync(int themeId);
|
Task<IEnumerable<AppUserPreferences>> GetAllPreferencesByThemeAsync(int themeId);
|
||||||
Task<bool> HasAccessToLibrary(int libraryId, int userId);
|
Task<bool> HasAccessToLibrary(int libraryId, int userId);
|
||||||
Task<bool> HasAccessToSeries(int userId, int seriesId);
|
Task<bool> HasAccessToSeries(int userId, int seriesId);
|
||||||
@ -240,10 +240,12 @@ public class UserRepository : IUserRepository
|
|||||||
.ToListAsync();
|
.ToListAsync();
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<AppUser?> GetUserByEmailAsync(string email)
|
public async Task<AppUser?> GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None)
|
||||||
{
|
{
|
||||||
var lowerEmail = email.ToLower();
|
var lowerEmail = email.ToLower();
|
||||||
return await _context.AppUser.SingleOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail));
|
return await _context.AppUser
|
||||||
|
.Includes(includes)
|
||||||
|
.FirstOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
101
API/Middleware/CustomAuthHeaderMiddleware.cs
Normal file
101
API/Middleware/CustomAuthHeaderMiddleware.cs
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
using System;
|
||||||
|
using System.Linq;
|
||||||
|
using System.Net;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using API.Data;
|
||||||
|
using API.Services;
|
||||||
|
using Microsoft.AspNetCore.Http;
|
||||||
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
using Microsoft.Extensions.Logging;
|
||||||
|
|
||||||
|
namespace API.Middleware;
|
||||||
|
|
||||||
|
public class CustomAuthHeaderMiddleware(RequestDelegate next)
|
||||||
|
{
|
||||||
|
// Hardcoded list of allowed IP addresses in CIDR format
|
||||||
|
private readonly string[] allowedIpAddresses = { "192.168.1.0/24", "2001:db8::/32", "116.202.233.5", "104.21.81.112" };
|
||||||
|
|
||||||
|
|
||||||
|
public async Task Invoke(HttpContext context, IUnitOfWork unitOfWork, ILogger<CustomAuthHeaderMiddleware> logger, ITokenService tokenService)
|
||||||
|
{
|
||||||
|
// Extract user information from the custom header
|
||||||
|
string remoteUser = context.Request.Headers["Remote-User"];
|
||||||
|
|
||||||
|
// If header missing or user already authenticated, move on
|
||||||
|
if (string.IsNullOrEmpty(remoteUser) || context.User.Identity is {IsAuthenticated: true})
|
||||||
|
{
|
||||||
|
await next(context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate IP address
|
||||||
|
if (IsValidIpAddress(context.Connection.RemoteIpAddress))
|
||||||
|
{
|
||||||
|
// Perform additional authentication logic if needed
|
||||||
|
// For now, you can log the authenticated user
|
||||||
|
var user = await unitOfWork.UserRepository.GetUserByEmailAsync(remoteUser);
|
||||||
|
if (user == null)
|
||||||
|
{
|
||||||
|
// Tell security log maybe?
|
||||||
|
context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Check if the RemoteUser has an account on the server
|
||||||
|
// if (!context.Request.Path.Equals("/login", StringComparison.OrdinalIgnoreCase))
|
||||||
|
// {
|
||||||
|
// // Attach the Auth header and allow it to pass through
|
||||||
|
// var token = await tokenService.CreateToken(user);
|
||||||
|
// context.Request.Headers.Add("Authorization", $"Bearer {token}");
|
||||||
|
// //context.Response.Redirect($"/login?apiKey={user.ApiKey}");
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
// Attach the Auth header and allow it to pass through
|
||||||
|
var token = await tokenService.CreateToken(user);
|
||||||
|
context.Request.Headers.Append("Authorization", $"Bearer {token}");
|
||||||
|
await next(context);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
|
||||||
|
await next(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
private bool IsValidIpAddress(IPAddress ipAddress)
|
||||||
|
{
|
||||||
|
// Check if the IP address is in the whitelist
|
||||||
|
return allowedIpAddresses.Any(ipRange => IpAddressRange.Parse(ipRange).Contains(ipAddress));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper class for IP address range parsing
|
||||||
|
public class IpAddressRange
|
||||||
|
{
|
||||||
|
private readonly uint _startAddress;
|
||||||
|
private readonly uint _endAddress;
|
||||||
|
|
||||||
|
private IpAddressRange(uint startAddress, uint endAddress)
|
||||||
|
{
|
||||||
|
_startAddress = startAddress;
|
||||||
|
_endAddress = endAddress;
|
||||||
|
}
|
||||||
|
|
||||||
|
public bool Contains(IPAddress address)
|
||||||
|
{
|
||||||
|
var ipAddressBytes = address.GetAddressBytes();
|
||||||
|
var ipAddress = BitConverter.ToUInt32(ipAddressBytes.Reverse().ToArray(), 0);
|
||||||
|
return ipAddress >= _startAddress && ipAddress <= _endAddress;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static IpAddressRange Parse(string ipRange)
|
||||||
|
{
|
||||||
|
var parts = ipRange.Split('/');
|
||||||
|
var ipAddress = IPAddress.Parse(parts[0]);
|
||||||
|
var maskBits = int.Parse(parts[1]);
|
||||||
|
|
||||||
|
var ipBytes = ipAddress.GetAddressBytes().Reverse().ToArray();
|
||||||
|
var startAddress = BitConverter.ToUInt32(ipBytes, 0);
|
||||||
|
var endAddress = startAddress | (uint.MaxValue >> maskBits);
|
||||||
|
|
||||||
|
return new IpAddressRange(startAddress, endAddress);
|
||||||
|
}
|
||||||
|
}
|
@ -9,27 +9,17 @@ using Microsoft.Extensions.Logging;
|
|||||||
|
|
||||||
namespace API.Middleware;
|
namespace API.Middleware;
|
||||||
|
|
||||||
public class ExceptionMiddleware
|
public class ExceptionMiddleware(RequestDelegate next, ILogger<ExceptionMiddleware> logger)
|
||||||
{
|
{
|
||||||
private readonly RequestDelegate _next;
|
|
||||||
private readonly ILogger<ExceptionMiddleware> _logger;
|
|
||||||
|
|
||||||
|
|
||||||
public ExceptionMiddleware(RequestDelegate next, ILogger<ExceptionMiddleware> logger)
|
|
||||||
{
|
|
||||||
_next = next;
|
|
||||||
_logger = logger;
|
|
||||||
}
|
|
||||||
|
|
||||||
public async Task InvokeAsync(HttpContext context)
|
public async Task InvokeAsync(HttpContext context)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await _next(context); // downstream middlewares or http call
|
await next(context); // downstream middlewares or http call
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
_logger.LogError(ex, "There was an exception");
|
logger.LogError(ex, "There was an exception");
|
||||||
context.Response.ContentType = "application/json";
|
context.Response.ContentType = "application/json";
|
||||||
context.Response.StatusCode = (int) HttpStatusCode.InternalServerError;
|
context.Response.StatusCode = (int) HttpStatusCode.InternalServerError;
|
||||||
|
|
||||||
|
@ -10,24 +10,16 @@ namespace API.Middleware;
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Responsible for maintaining an in-memory. Not in use
|
/// Responsible for maintaining an in-memory. Not in use
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public class JwtRevocationMiddleware
|
public class JwtRevocationMiddleware(
|
||||||
|
RequestDelegate next,
|
||||||
|
IEasyCachingProviderFactory cacheFactory,
|
||||||
|
ILogger<JwtRevocationMiddleware> logger)
|
||||||
{
|
{
|
||||||
private readonly RequestDelegate _next;
|
|
||||||
private readonly IEasyCachingProviderFactory _cacheFactory;
|
|
||||||
private readonly ILogger<JwtRevocationMiddleware> _logger;
|
|
||||||
|
|
||||||
public JwtRevocationMiddleware(RequestDelegate next, IEasyCachingProviderFactory cacheFactory, ILogger<JwtRevocationMiddleware> logger)
|
|
||||||
{
|
|
||||||
_next = next;
|
|
||||||
_cacheFactory = cacheFactory;
|
|
||||||
_logger = logger;
|
|
||||||
}
|
|
||||||
|
|
||||||
public async Task InvokeAsync(HttpContext context)
|
public async Task InvokeAsync(HttpContext context)
|
||||||
{
|
{
|
||||||
if (context.User.Identity is {IsAuthenticated: false})
|
if (context.User.Identity is {IsAuthenticated: false})
|
||||||
{
|
{
|
||||||
await _next(context);
|
await next(context);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,18 +29,18 @@ public class JwtRevocationMiddleware
|
|||||||
// Check if the token is revoked
|
// Check if the token is revoked
|
||||||
if (await IsTokenRevoked(token))
|
if (await IsTokenRevoked(token))
|
||||||
{
|
{
|
||||||
_logger.LogWarning("Revoked token detected: {Token}", token);
|
logger.LogWarning("Revoked token detected: {Token}", token);
|
||||||
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
|
context.Response.StatusCode = StatusCodes.Status401Unauthorized;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
await _next(context);
|
await next(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task<bool> IsTokenRevoked(string token)
|
private async Task<bool> IsTokenRevoked(string token)
|
||||||
{
|
{
|
||||||
// Check if the token exists in the revocation list stored in the cache
|
// Check if the token exists in the revocation list stored in the cache
|
||||||
var isRevoked = await _cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt)
|
var isRevoked = await cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt)
|
||||||
.GetAsync<string>(token);
|
.GetAsync<string>(token);
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,37 +7,29 @@ using API.Errors;
|
|||||||
using Kavita.Common;
|
using Kavita.Common;
|
||||||
using Microsoft.AspNetCore.Http;
|
using Microsoft.AspNetCore.Http;
|
||||||
using Serilog;
|
using Serilog;
|
||||||
using ILogger = Serilog.ILogger;
|
using ILogger = Serilog.Core.Logger;
|
||||||
|
|
||||||
namespace API.Middleware;
|
namespace API.Middleware;
|
||||||
|
|
||||||
public class SecurityEventMiddleware
|
public class SecurityEventMiddleware(RequestDelegate next)
|
||||||
{
|
{
|
||||||
private readonly RequestDelegate _next;
|
private readonly ILogger _logger = new LoggerConfiguration()
|
||||||
private readonly ILogger _logger;
|
.MinimumLevel.Debug()
|
||||||
|
.WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day)
|
||||||
public SecurityEventMiddleware(RequestDelegate next)
|
.CreateLogger();
|
||||||
{
|
|
||||||
_next = next;
|
|
||||||
|
|
||||||
_logger = new LoggerConfiguration()
|
|
||||||
.MinimumLevel.Debug()
|
|
||||||
.WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day)
|
|
||||||
.CreateLogger();
|
|
||||||
}
|
|
||||||
|
|
||||||
public async Task InvokeAsync(HttpContext context)
|
public async Task InvokeAsync(HttpContext context)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
await _next(context);
|
await next(context);
|
||||||
}
|
}
|
||||||
catch (KavitaUnauthenticatedUserException ex)
|
catch (KavitaUnauthenticatedUserException ex)
|
||||||
{
|
{
|
||||||
var ipAddress = context.Connection.RemoteIpAddress?.ToString();
|
var ipAddress = context.Connection.RemoteIpAddress?.ToString();
|
||||||
var requestMethod = context.Request.Method;
|
var requestMethod = context.Request.Method;
|
||||||
var requestPath = context.Request.Path;
|
var requestPath = context.Request.Path;
|
||||||
var userAgent = context.Request.Headers["User-Agent"];
|
var userAgent = context.Request.Headers.UserAgent;
|
||||||
var securityEvent = new
|
var securityEvent = new
|
||||||
{
|
{
|
||||||
IpAddress = ipAddress,
|
IpAddress = ipAddress,
|
||||||
@ -57,8 +49,7 @@ public class SecurityEventMiddleware
|
|||||||
|
|
||||||
var options = new JsonSerializerOptions
|
var options = new JsonSerializerOptions
|
||||||
{
|
{
|
||||||
PropertyNamingPolicy =
|
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
|
||||||
JsonNamingPolicy.CamelCase
|
|
||||||
};
|
};
|
||||||
|
|
||||||
var json = JsonSerializer.Serialize(response, options);
|
var json = JsonSerializer.Serialize(response, options);
|
||||||
|
@ -261,6 +261,7 @@ public class Startup
|
|||||||
|
|
||||||
app.UseMiddleware<ExceptionMiddleware>();
|
app.UseMiddleware<ExceptionMiddleware>();
|
||||||
app.UseMiddleware<SecurityEventMiddleware>();
|
app.UseMiddleware<SecurityEventMiddleware>();
|
||||||
|
app.UseMiddleware<CustomAuthHeaderMiddleware>();
|
||||||
|
|
||||||
|
|
||||||
if (env.IsDevelopment())
|
if (env.IsDevelopment())
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
<TargetFramework>net8.0</TargetFramework>
|
<TargetFramework>net8.0</TargetFramework>
|
||||||
<Company>kavitareader.com</Company>
|
<Company>kavitareader.com</Company>
|
||||||
<Product>Kavita</Product>
|
<Product>Kavita</Product>
|
||||||
<AssemblyVersion>0.7.11.7</AssemblyVersion>
|
<AssemblyVersion>0.7.11.10</AssemblyVersion>
|
||||||
<NeutralLanguage>en</NeutralLanguage>
|
<NeutralLanguage>en</NeutralLanguage>
|
||||||
<TieredPGO>true</TieredPGO>
|
<TieredPGO>true</TieredPGO>
|
||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
"name": "GPL-3.0",
|
"name": "GPL-3.0",
|
||||||
"url": "https://github.com/Kareadita/Kavita/blob/develop/LICENSE"
|
"url": "https://github.com/Kareadita/Kavita/blob/develop/LICENSE"
|
||||||
},
|
},
|
||||||
"version": "0.7.11.6"
|
"version": "0.7.11.10"
|
||||||
},
|
},
|
||||||
"servers": [
|
"servers": [
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user