http: centralise effective protocol resolution for HTTPS RR ALPN

This commit is contained in:
Zen Dodd 2026-04-18 11:31:12 +10:00
parent 904f9fddcc
commit 710902ddc3
No known key found for this signature in database
GPG Key ID: 6909546B2C52EC2D
4 changed files with 45 additions and 76 deletions

View File

@ -20,10 +20,8 @@ import (
"crypto/tls"
"errors"
"fmt"
"maps"
"net"
"net/http"
"slices"
"strconv"
"sync"
"time"
@ -236,12 +234,7 @@ func (app *App) Provision(ctx caddy.Context) error {
// if no protocols configured explicitly, enable all except h2c
if len(srv.Protocols) == 0 {
srv.Protocols = slices.Clone(srv.protocolsWithDefaults())
}
srvProtocolsUnique := map[string]struct{}{}
for _, srvProtocol := range srv.Protocols {
srvProtocolsUnique[srvProtocol] = struct{}{}
srv.Protocols = srv.protocolsWithDefaults()
}
if srv.ListenProtocols != nil {
@ -252,31 +245,7 @@ func (app *App) Provision(ctx caddy.Context) error {
for i, lnProtocols := range srv.ListenProtocols {
if lnProtocols != nil {
// populate empty listen protocols with server protocols
lnProtocolsDefault := false
var lnProtocolsInclude []string
srvProtocolsInclude := maps.Clone(srvProtocolsUnique)
// keep existing listener protocols unless they are empty
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" {
lnProtocolsDefault = true
} else {
lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol)
delete(srvProtocolsInclude, lnProtocol)
}
}
// append server protocols to listener protocols if any listener protocols were empty
if lnProtocolsDefault {
for _, srvProtocol := range srv.Protocols {
if _, ok := srvProtocolsInclude[srvProtocol]; ok {
lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol)
}
}
}
srv.ListenProtocols[i] = lnProtocolsInclude
srv.ListenProtocols[i] = srv.listenerProtocolsWithDefaults(lnProtocols)
}
}
}

View File

@ -551,43 +551,14 @@ func (app *App) makeRedirRoute(redirToPort uint, matcherSet MatcherSet) Route {
}
func httpsRRALPNs(srv *Server) []string {
// Automatic HTTPS runs before server provisioning fills in the default
// protocols, so derive the effective set directly from the raw config here.
serverProtocols := srv.protocolsWithDefaults()
protocols := make(map[string]struct{}, len(serverProtocols))
if srv.ListenProtocols == nil {
for _, protocol := range serverProtocols {
protocols[protocol] = struct{}{}
}
} else {
for _, lnProtocols := range srv.ListenProtocols {
if len(lnProtocols) == 0 {
for _, protocol := range serverProtocols {
protocols[protocol] = struct{}{}
}
continue
}
for _, protocol := range lnProtocols {
if protocol == "" {
for _, inherited := range serverProtocols {
protocols[inherited] = struct{}{}
}
continue
}
protocols[protocol] = struct{}{}
}
}
}
alpn := make(map[string]struct{}, 3)
if _, ok := protocols["h3"]; ok {
if srv.protocol("h3") {
alpn["h3"] = struct{}{}
}
if _, ok := protocols["h2"]; ok {
if srv.protocol("h2") {
alpn["h2"] = struct{}{}
}
if _, ok := protocols["h1"]; ok {
if srv.protocol("h1") {
alpn["http/1.1"] = struct{}{}
}
return caddytls.OrderedHTTPSRRALPN(alpn)

View File

@ -22,7 +22,8 @@ func TestHTTPSRRALPNsListenProtocolOverrides(t *testing.T) {
ListenProtocols: [][]string{
{"h1"},
nil,
{"h2c", "h3"},
{},
{"h3", ""},
},
}

View File

@ -902,18 +902,13 @@ func (s *Server) logRequest(
// protocol returns true if the protocol proto is configured/enabled.
func (s *Server) protocol(proto string) bool {
if s.ListenProtocols == nil {
if slices.Contains(s.protocolsWithDefaults(), proto) {
return slices.Contains(s.protocolsWithDefaults(), proto)
}
for _, lnProtocols := range s.ListenProtocols {
if slices.Contains(s.listenerProtocolsWithDefaults(lnProtocols), proto) {
return true
}
} else {
serverProtocols := s.protocolsWithDefaults()
for _, lnProtocols := range s.ListenProtocols {
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" && slices.Contains(serverProtocols, proto) || lnProtocol == proto {
return true
}
}
}
}
return false
@ -926,6 +921,39 @@ func (s *Server) protocolsWithDefaults() []string {
return s.Protocols
}
func (s *Server) listenerProtocolsWithDefaults(lnProtocols []string) []string {
serverProtocols := s.protocolsWithDefaults()
if len(lnProtocols) == 0 {
return serverProtocols
}
lnProtocolsDefault := false
lnProtocolsInclude := make([]string, 0, len(lnProtocols)+len(serverProtocols))
srvProtocolsInclude := make(map[string]struct{}, len(serverProtocols))
for _, srvProtocol := range serverProtocols {
srvProtocolsInclude[srvProtocol] = struct{}{}
}
for _, lnProtocol := range lnProtocols {
if lnProtocol == "" {
lnProtocolsDefault = true
continue
}
lnProtocolsInclude = append(lnProtocolsInclude, lnProtocol)
delete(srvProtocolsInclude, lnProtocol)
}
if lnProtocolsDefault {
for _, srvProtocol := range serverProtocols {
if _, ok := srvProtocolsInclude[srvProtocol]; ok {
lnProtocolsInclude = append(lnProtocolsInclude, srvProtocol)
}
}
}
return lnProtocolsInclude
}
// Listeners returns the server's listeners. These are active listeners,
// so calling Accept() or Close() on them will probably break things.
// They are made available here for read-only purposes (e.g. Addr())