diff --git a/back/.env.example b/back/.env.example
index ce1153e2..0b3a2aef 100644
--- a/back/.env.example
+++ b/back/.env.example
@@ -4,6 +4,7 @@
# http route prefix (will listen to $KYOO_PREFIX/movie for example)
KYOO_PREFIX=""
+# Postgres settings
# POSTGRES_URL=postgres://user:password@hostname:port/dbname?sslmode=verify-full&sslrootcert=/path/to/server.crt&sslcert=/path/to/client.crt&sslkey=/path/to/client.key
# The behavior of the below variables match what is documented here:
# https://www.postgresql.org/docs/current/libpq-envars.html
@@ -18,3 +19,12 @@ PGPORT=5432
# PGSSLROOTCERT=/my/serving.crt
# PGSSLCERT=/my/client.crt
# PGSSLKEY=/my/client.key
+
+# RabbitMQ settings
+# Full list of options: https://www.rabbitmq.com/uri-spec.html, https://www.rabbitmq.com/docs/uri-query-parameters
+# RABBITMQ_URL=amqps://user:password@rabbitmq-server:1234/vhost?cacertfile=/path/to/cacert.pem&certfile=/path/to/cert.pem&keyfile=/path/to/key.pem&verify=verify_peer&auth_mechanism=EXTERNAL
+# These values override what is provided the the URL variable
+RABBITMQ_DEFAULT_USER=guest
+RABBITMQ_DEFAULT_PASS=guest
+RABBITMQ_HOST=rabbitmq
+RABBITMQ_PORT=5672
diff --git a/back/src/Kyoo.RabbitMq/RabbitMqModule.cs b/back/src/Kyoo.RabbitMq/RabbitMqModule.cs
index ed9ef48f..8a489bbb 100644
--- a/back/src/Kyoo.RabbitMq/RabbitMqModule.cs
+++ b/back/src/Kyoo.RabbitMq/RabbitMqModule.cs
@@ -16,8 +16,11 @@
// You should have received a copy of the GNU General Public License
// along with Kyoo. If not, see .
+using System.Net.Security;
+using System.Security.Cryptography.X509Certificates;
using Kyoo.Abstractions.Controllers;
using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using RabbitMQ.Client;
@@ -30,18 +33,215 @@ public static class RabbitMqModule
{
builder.Services.AddSingleton(_ =>
{
- ConnectionFactory factory =
- new()
- {
- UserName = builder.Configuration.GetValue("RABBITMQ_DEFAULT_USER", "guest"),
- Password = builder.Configuration.GetValue("RABBITMQ_DEFAULT_PASS", "guest"),
- HostName = builder.Configuration.GetValue("RABBITMQ_HOST", "rabbitmq"),
- Port = builder.Configuration.GetValue("RABBITMQ_PORT", 5672),
- };
+ ConnectionFactory factory = new();
+
+ // See https://www.rabbitmq.com/docs/uri-spec
+ string? connectionString = builder.Configuration.GetValue("RABBITMQ_URL");
+ if (!string.IsNullOrEmpty(connectionString))
+ factory._ConfigureFactoryWithConnectionString(connectionString);
+ else
+ factory._ConfigureFactoryWithEnvironmentVars(builder.Configuration);
return factory.CreateConnection();
});
builder.Services.AddSingleton();
builder.Services.AddSingleton();
}
+
+ private static void _ConfigureFactoryWithConnectionString(
+ this ConnectionFactory factory,
+ string? connectionString
+ )
+ {
+ if (string.IsNullOrEmpty(connectionString))
+ return;
+
+ // Important: setting this property will not use any query parameters, so they must be parsed here instead..
+ factory.Uri = new Uri(connectionString);
+
+ // Support query parameters defined here:
+ // https://www.rabbitmq.com/docs/uri-query-parameters
+ Dictionary queryParameters =
+ QueryHelpers.ParseQuery(factory.Uri.Query);
+
+ queryParameters.TryGetValue(
+ "heartbeat",
+ out Microsoft.Extensions.Primitives.StringValues heartbeats
+ );
+ if (int.TryParse(heartbeats.LastOrDefault(), out int heartbeatValue))
+ factory.RequestedHeartbeat = TimeSpan.FromSeconds(heartbeatValue);
+
+ queryParameters.TryGetValue(
+ "connection_timeout",
+ out Microsoft.Extensions.Primitives.StringValues connectionTimeouts
+ );
+ if (int.TryParse(connectionTimeouts.LastOrDefault(), out int connectionTimeoutValue))
+ factory.RequestedConnectionTimeout = TimeSpan.FromSeconds(connectionTimeoutValue);
+
+ queryParameters.TryGetValue(
+ "channel_max",
+ out Microsoft.Extensions.Primitives.StringValues channelMaxValues
+ );
+ if (ushort.TryParse(channelMaxValues.LastOrDefault(), out ushort channelMaxValue))
+ factory.RequestedChannelMax = channelMaxValue;
+
+ if (!factory.Ssl.Enabled)
+ return;
+
+ queryParameters.TryGetValue(
+ "cacertfile",
+ out Microsoft.Extensions.Primitives.StringValues caCertFiles
+ );
+ var caCertFile = caCertFiles.LastOrDefault();
+ if (!string.IsNullOrEmpty(caCertFile))
+ {
+ // Load the cert once at startup instead of on every connection.
+ X509Certificate2 rootCA = new(caCertFile);
+
+ // This is a custom validator that obeys the set SslPolicyErrors, while also using the CA cert specified in the query string.
+ factory.Ssl.CertificateValidationCallback = (
+ sender,
+ certificate,
+ chain,
+ sslPolicyErrors
+ ) =>
+ {
+ // If no cert was provided
+ if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNotAvailable))
+ {
+ // Accept the cert anyway if the client was explicitly configured to ignore this.
+ if (
+ factory.Ssl.AcceptablePolicyErrors.HasFlag(
+ SslPolicyErrors.RemoteCertificateNotAvailable
+ )
+ )
+ return true;
+ // Otherwise, reject it.
+ return false;
+ }
+
+ // If the cert hostname does not match
+ if (sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch))
+ {
+ // Accept the cert anyway if the client was explicitly configured to ignore this.
+ if (
+ factory.Ssl.AcceptablePolicyErrors.HasFlag(
+ SslPolicyErrors.RemoteCertificateNameMismatch
+ )
+ )
+ return true;
+ // Otherwise, reject it.
+ return false;
+ }
+
+ // This shouldn't ever happen, and is mostly just here to satisfy the linter
+ if (chain == null || certificate == null)
+ return false;
+
+ // Verify that the certificate came from the specified CA.
+ chain.ChainPolicy.ExtraStore.AddRange(
+ chain.ChainElements.Select(x => x.Certificate).ToArray()
+ );
+ chain.ChainPolicy.CustomTrustStore.Clear();
+ chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
+ chain.ChainPolicy.CustomTrustStore.Add(rootCA);
+
+ return chain.Build(new X509Certificate2(certificate));
+ };
+ }
+
+ queryParameters.TryGetValue("certfile", out var certfiles);
+ var certfile = certfiles.LastOrDefault();
+ queryParameters.TryGetValue("keyfile", out var keyfiles);
+ var keyfile = keyfiles.LastOrDefault();
+ if (!string.IsNullOrEmpty(certfile) && !string.IsNullOrEmpty(keyfile))
+ factory.Ssl.Certs = [X509Certificate2.CreateFromPemFile(certfile, keyfile)];
+
+ queryParameters.TryGetValue("verify", out var verifyValues);
+ switch (verifyValues.LastOrDefault())
+ {
+ case "verify_none":
+ factory.Ssl.AcceptablePolicyErrors = ~SslPolicyErrors.None;
+ break;
+ case "verify_peer":
+ factory.Ssl.AcceptablePolicyErrors = SslPolicyErrors.None;
+ break;
+ }
+
+ queryParameters.TryGetValue(
+ "server_name_indication",
+ out Microsoft.Extensions.Primitives.StringValues sniValues
+ );
+ var sni = sniValues.LastOrDefault();
+ if (!string.IsNullOrEmpty(sni))
+ {
+ if (sni == "disabled") // Special value, see https://www.rabbitmq.com/docs/ssl#erlang-ssl
+ {
+ factory.Ssl.ServerName = null;
+ factory.Ssl.AcceptablePolicyErrors |= SslPolicyErrors.RemoteCertificateNameMismatch;
+ }
+ else
+ factory.Ssl.ServerName = sni;
+ }
+
+ queryParameters.TryGetValue(
+ "auth_mechanism",
+ out Microsoft.Extensions.Primitives.StringValues authMechanisms
+ );
+ if (authMechanisms.Count > 0)
+ {
+ factory.AuthMechanisms.Clear();
+ foreach (var authMechanism in authMechanisms)
+ {
+ switch (authMechanism)
+ {
+ case "external":
+ factory.AuthMechanisms.Add(new ExternalMechanismFactory());
+ break;
+ case "plain":
+ factory.AuthMechanisms.Add(new PlainMechanismFactory());
+ break;
+ default:
+ throw new NotSupportedException(
+ $"Unsupported authentication mechanism: {authMechanism}"
+ );
+ }
+ }
+ }
+ }
+
+ private static void _ConfigureFactoryWithEnvironmentVars(
+ this ConnectionFactory factory,
+ IConfigurationManager configuration
+ )
+ {
+ factory.UserName = _GetNonEmptyString(
+ configuration.GetValue("RABBITMQ_DEFAULT_USER"),
+ factory.UserName,
+ "guest"
+ );
+ factory.Password = _GetNonEmptyString(
+ configuration.GetValue("RABBITMQ_DEFAULT_PASS"),
+ factory.Password,
+ "guest"
+ );
+ factory.HostName = _GetNonEmptyString(
+ configuration.GetValue("RABBITMQ_HOST"),
+ factory.HostName,
+ "rabbitmq"
+ );
+ var port = configuration.GetValue("RABBITMQ_PORT");
+ if (port != null)
+ factory.Port = port.Value;
+ else if (factory.Port == 0)
+ factory.Port = 5672;
+ }
+
+ private static string _GetNonEmptyString(params string?[] values)
+ {
+ foreach (var value in values)
+ if (!string.IsNullOrEmpty(value))
+ return value;
+ return string.Empty;
+ }
}