Custom Headers Test 4 (#2524)

This commit is contained in:
Joe Milazzo 2024-01-05 14:56:46 -06:00 committed by GitHub
parent d85208513f
commit 3765ad929a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 129 additions and 52 deletions

View File

@ -72,7 +72,7 @@ public interface IUserRepository
Task<AppUser?> GetUserByIdAsync(int userId, AppUserIncludes includeFlags = AppUserIncludes.None); Task<AppUser?> GetUserByIdAsync(int userId, AppUserIncludes includeFlags = AppUserIncludes.None);
Task<int> GetUserIdByUsernameAsync(string username); Task<int> GetUserIdByUsernameAsync(string username);
Task<IList<AppUserBookmark>> GetAllBookmarksByIds(IList<int> bookmarkIds); Task<IList<AppUserBookmark>> GetAllBookmarksByIds(IList<int> bookmarkIds);
Task<AppUser?> GetUserByEmailAsync(string email); Task<AppUser?> GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None);
Task<IEnumerable<AppUserPreferences>> GetAllPreferencesByThemeAsync(int themeId); Task<IEnumerable<AppUserPreferences>> GetAllPreferencesByThemeAsync(int themeId);
Task<bool> HasAccessToLibrary(int libraryId, int userId); Task<bool> HasAccessToLibrary(int libraryId, int userId);
Task<bool> HasAccessToSeries(int userId, int seriesId); Task<bool> HasAccessToSeries(int userId, int seriesId);
@ -240,10 +240,12 @@ public class UserRepository : IUserRepository
.ToListAsync(); .ToListAsync();
} }
public async Task<AppUser?> GetUserByEmailAsync(string email) public async Task<AppUser?> GetUserByEmailAsync(string email, AppUserIncludes includes = AppUserIncludes.None)
{ {
var lowerEmail = email.ToLower(); var lowerEmail = email.ToLower();
return await _context.AppUser.SingleOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail)); return await _context.AppUser
.Includes(includes)
.FirstOrDefaultAsync(u => u.Email != null && u.Email.ToLower().Equals(lowerEmail));
} }

View File

@ -0,0 +1,101 @@
using System;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
using API.Data;
using API.Services;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace API.Middleware;
public class CustomAuthHeaderMiddleware(RequestDelegate next)
{
// Hardcoded list of allowed IP addresses in CIDR format
private readonly string[] allowedIpAddresses = { "192.168.1.0/24", "2001:db8::/32", "116.202.233.5", "104.21.81.112" };
public async Task Invoke(HttpContext context, IUnitOfWork unitOfWork, ILogger<CustomAuthHeaderMiddleware> logger, ITokenService tokenService)
{
// Extract user information from the custom header
string remoteUser = context.Request.Headers["Remote-User"];
// If header missing or user already authenticated, move on
if (string.IsNullOrEmpty(remoteUser) || context.User.Identity is {IsAuthenticated: true})
{
await next(context);
return;
}
// Validate IP address
if (IsValidIpAddress(context.Connection.RemoteIpAddress))
{
// Perform additional authentication logic if needed
// For now, you can log the authenticated user
var user = await unitOfWork.UserRepository.GetUserByEmailAsync(remoteUser);
if (user == null)
{
// Tell security log maybe?
context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
return;
}
// Check if the RemoteUser has an account on the server
// if (!context.Request.Path.Equals("/login", StringComparison.OrdinalIgnoreCase))
// {
// // Attach the Auth header and allow it to pass through
// var token = await tokenService.CreateToken(user);
// context.Request.Headers.Add("Authorization", $"Bearer {token}");
// //context.Response.Redirect($"/login?apiKey={user.ApiKey}");
// return;
// }
// Attach the Auth header and allow it to pass through
var token = await tokenService.CreateToken(user);
context.Request.Headers.Append("Authorization", $"Bearer {token}");
await next(context);
return;
}
context.Response.StatusCode = (int)HttpStatusCode.Unauthorized;
await next(context);
}
private bool IsValidIpAddress(IPAddress ipAddress)
{
// Check if the IP address is in the whitelist
return allowedIpAddresses.Any(ipRange => IpAddressRange.Parse(ipRange).Contains(ipAddress));
}
}
// Helper class for IP address range parsing
public class IpAddressRange
{
private readonly uint _startAddress;
private readonly uint _endAddress;
private IpAddressRange(uint startAddress, uint endAddress)
{
_startAddress = startAddress;
_endAddress = endAddress;
}
public bool Contains(IPAddress address)
{
var ipAddressBytes = address.GetAddressBytes();
var ipAddress = BitConverter.ToUInt32(ipAddressBytes.Reverse().ToArray(), 0);
return ipAddress >= _startAddress && ipAddress <= _endAddress;
}
public static IpAddressRange Parse(string ipRange)
{
var parts = ipRange.Split('/');
var ipAddress = IPAddress.Parse(parts[0]);
var maskBits = int.Parse(parts[1]);
var ipBytes = ipAddress.GetAddressBytes().Reverse().ToArray();
var startAddress = BitConverter.ToUInt32(ipBytes, 0);
var endAddress = startAddress | (uint.MaxValue >> maskBits);
return new IpAddressRange(startAddress, endAddress);
}
}

View File

@ -9,27 +9,17 @@ using Microsoft.Extensions.Logging;
namespace API.Middleware; namespace API.Middleware;
public class ExceptionMiddleware public class ExceptionMiddleware(RequestDelegate next, ILogger<ExceptionMiddleware> logger)
{ {
private readonly RequestDelegate _next;
private readonly ILogger<ExceptionMiddleware> _logger;
public ExceptionMiddleware(RequestDelegate next, ILogger<ExceptionMiddleware> logger)
{
_next = next;
_logger = logger;
}
public async Task InvokeAsync(HttpContext context) public async Task InvokeAsync(HttpContext context)
{ {
try try
{ {
await _next(context); // downstream middlewares or http call await next(context); // downstream middlewares or http call
} }
catch (Exception ex) catch (Exception ex)
{ {
_logger.LogError(ex, "There was an exception"); logger.LogError(ex, "There was an exception");
context.Response.ContentType = "application/json"; context.Response.ContentType = "application/json";
context.Response.StatusCode = (int) HttpStatusCode.InternalServerError; context.Response.StatusCode = (int) HttpStatusCode.InternalServerError;

View File

@ -10,24 +10,16 @@ namespace API.Middleware;
/// <summary> /// <summary>
/// Responsible for maintaining an in-memory. Not in use /// Responsible for maintaining an in-memory. Not in use
/// </summary> /// </summary>
public class JwtRevocationMiddleware public class JwtRevocationMiddleware(
RequestDelegate next,
IEasyCachingProviderFactory cacheFactory,
ILogger<JwtRevocationMiddleware> logger)
{ {
private readonly RequestDelegate _next;
private readonly IEasyCachingProviderFactory _cacheFactory;
private readonly ILogger<JwtRevocationMiddleware> _logger;
public JwtRevocationMiddleware(RequestDelegate next, IEasyCachingProviderFactory cacheFactory, ILogger<JwtRevocationMiddleware> logger)
{
_next = next;
_cacheFactory = cacheFactory;
_logger = logger;
}
public async Task InvokeAsync(HttpContext context) public async Task InvokeAsync(HttpContext context)
{ {
if (context.User.Identity is {IsAuthenticated: false}) if (context.User.Identity is {IsAuthenticated: false})
{ {
await _next(context); await next(context);
return; return;
} }
@ -37,18 +29,18 @@ public class JwtRevocationMiddleware
// Check if the token is revoked // Check if the token is revoked
if (await IsTokenRevoked(token)) if (await IsTokenRevoked(token))
{ {
_logger.LogWarning("Revoked token detected: {Token}", token); logger.LogWarning("Revoked token detected: {Token}", token);
context.Response.StatusCode = StatusCodes.Status401Unauthorized; context.Response.StatusCode = StatusCodes.Status401Unauthorized;
return; return;
} }
await _next(context); await next(context);
} }
private async Task<bool> IsTokenRevoked(string token) private async Task<bool> IsTokenRevoked(string token)
{ {
// Check if the token exists in the revocation list stored in the cache // Check if the token exists in the revocation list stored in the cache
var isRevoked = await _cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt) var isRevoked = await cacheFactory.GetCachingProvider(EasyCacheProfiles.RevokedJwt)
.GetAsync<string>(token); .GetAsync<string>(token);

View File

@ -7,37 +7,29 @@ using API.Errors;
using Kavita.Common; using Kavita.Common;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Serilog; using Serilog;
using ILogger = Serilog.ILogger; using ILogger = Serilog.Core.Logger;
namespace API.Middleware; namespace API.Middleware;
public class SecurityEventMiddleware public class SecurityEventMiddleware(RequestDelegate next)
{ {
private readonly RequestDelegate _next; private readonly ILogger _logger = new LoggerConfiguration()
private readonly ILogger _logger; .MinimumLevel.Debug()
.WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day)
public SecurityEventMiddleware(RequestDelegate next) .CreateLogger();
{
_next = next;
_logger = new LoggerConfiguration()
.MinimumLevel.Debug()
.WriteTo.File(Path.Join(Directory.GetCurrentDirectory(), "config/logs/", "security.log"), rollingInterval: RollingInterval.Day)
.CreateLogger();
}
public async Task InvokeAsync(HttpContext context) public async Task InvokeAsync(HttpContext context)
{ {
try try
{ {
await _next(context); await next(context);
} }
catch (KavitaUnauthenticatedUserException ex) catch (KavitaUnauthenticatedUserException ex)
{ {
var ipAddress = context.Connection.RemoteIpAddress?.ToString(); var ipAddress = context.Connection.RemoteIpAddress?.ToString();
var requestMethod = context.Request.Method; var requestMethod = context.Request.Method;
var requestPath = context.Request.Path; var requestPath = context.Request.Path;
var userAgent = context.Request.Headers["User-Agent"]; var userAgent = context.Request.Headers.UserAgent;
var securityEvent = new var securityEvent = new
{ {
IpAddress = ipAddress, IpAddress = ipAddress,
@ -57,8 +49,7 @@ public class SecurityEventMiddleware
var options = new JsonSerializerOptions var options = new JsonSerializerOptions
{ {
PropertyNamingPolicy = PropertyNamingPolicy = JsonNamingPolicy.CamelCase
JsonNamingPolicy.CamelCase
}; };
var json = JsonSerializer.Serialize(response, options); var json = JsonSerializer.Serialize(response, options);

View File

@ -261,6 +261,7 @@ public class Startup
app.UseMiddleware<ExceptionMiddleware>(); app.UseMiddleware<ExceptionMiddleware>();
app.UseMiddleware<SecurityEventMiddleware>(); app.UseMiddleware<SecurityEventMiddleware>();
app.UseMiddleware<CustomAuthHeaderMiddleware>();
if (env.IsDevelopment()) if (env.IsDevelopment())

View File

@ -4,7 +4,7 @@
<TargetFramework>net8.0</TargetFramework> <TargetFramework>net8.0</TargetFramework>
<Company>kavitareader.com</Company> <Company>kavitareader.com</Company>
<Product>Kavita</Product> <Product>Kavita</Product>
<AssemblyVersion>0.7.11.7</AssemblyVersion> <AssemblyVersion>0.7.11.10</AssemblyVersion>
<NeutralLanguage>en</NeutralLanguage> <NeutralLanguage>en</NeutralLanguage>
<TieredPGO>true</TieredPGO> <TieredPGO>true</TieredPGO>
</PropertyGroup> </PropertyGroup>

View File

@ -7,7 +7,7 @@
"name": "GPL-3.0", "name": "GPL-3.0",
"url": "https://github.com/Kareadita/Kavita/blob/develop/LICENSE" "url": "https://github.com/Kareadita/Kavita/blob/develop/LICENSE"
}, },
"version": "0.7.11.6" "version": "0.7.11.10"
}, },
"servers": [ "servers": [
{ {