diff --git a/listeners.go b/listeners.go index 01adc615d..e698e7b7f 100644 --- a/listeners.go +++ b/listeners.go @@ -361,7 +361,7 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui if end < start { return NetworkAddress{}, fmt.Errorf("end port must not be less than start port") } - if (end - start) > maxPortSpan { + if (end-start)+1 > maxPortSpan { return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan) } } diff --git a/network_test.go b/network_test.go new file mode 100644 index 000000000..309cb99a9 --- /dev/null +++ b/network_test.go @@ -0,0 +1,963 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package caddy + +import ( + "context" + "fmt" + "net" + "strings" + "sync" + "testing" + "time" +) + +func TestNetworkAddress_String_Consistency(t *testing.T) { + tests := []struct { + name string + addr NetworkAddress + }{ + { + name: "basic tcp", + addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8080}, + }, + { + name: "tcp with port range", + addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8090}, + }, + { + name: "unix socket", + addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"}, + }, + { + name: "udp", + addr: NetworkAddress{Network: "udp", Host: "0.0.0.0", StartPort: 53, EndPort: 53}, + }, + { + name: "ipv6", + addr: NetworkAddress{Network: "tcp", Host: "::1", StartPort: 80, EndPort: 80}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + str := test.addr.String() + + // Parse the string back + parsed, err := ParseNetworkAddress(str) + if err != nil { + t.Fatalf("Failed to parse string representation: %v", err) + } + + // Should be equivalent to original + if parsed.Network != test.addr.Network { + t.Errorf("Network mismatch: expected %s, got %s", test.addr.Network, parsed.Network) + } + if parsed.Host != test.addr.Host { + t.Errorf("Host mismatch: expected %s, got %s", test.addr.Host, parsed.Host) + } + if parsed.StartPort != test.addr.StartPort { + t.Errorf("StartPort mismatch: expected %d, got %d", test.addr.StartPort, parsed.StartPort) + } + if parsed.EndPort != test.addr.EndPort { + t.Errorf("EndPort mismatch: expected %d, got %d", test.addr.EndPort, parsed.EndPort) + } + }) + } +} + +func TestNetworkAddress_PortRangeSize_EdgeCases(t *testing.T) { + tests := []struct { + name string + addr NetworkAddress + expected uint + }{ + { + name: "single port", + addr: NetworkAddress{StartPort: 80, EndPort: 80}, + expected: 1, + }, + { + name: "invalid range (end < start)", + addr: NetworkAddress{StartPort: 8080, EndPort: 8070}, + expected: 0, + }, + { + name: "zero ports", + addr: NetworkAddress{StartPort: 0, EndPort: 0}, + expected: 1, + }, + { + name: "maximum range", + addr: NetworkAddress{StartPort: 1, EndPort: 65535}, + expected: 65535, + }, + { + name: "large range", + addr: NetworkAddress{StartPort: 8000, EndPort: 9000}, + expected: 1001, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + size := test.addr.PortRangeSize() + if size != test.expected { + t.Errorf("Expected %d, got %d", test.expected, size) + } + }) + } +} + +func TestNetworkAddress_At_Validation(t *testing.T) { + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 8080, + EndPort: 8090, + } + + // Test valid offsets + for offset := uint(0); offset <= 10; offset++ { + result := addr.At(offset) + expectedPort := 8080 + offset + + if result.StartPort != expectedPort || result.EndPort != expectedPort { + t.Errorf("Offset %d: expected port %d, got %d-%d", + offset, expectedPort, result.StartPort, result.EndPort) + } + + if result.Network != addr.Network || result.Host != addr.Host { + t.Errorf("Offset %d: network/host should be preserved", offset) + } + } +} + +func TestNetworkAddress_Expand_LargeRange(t *testing.T) { + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 8000, + EndPort: 8010, + } + + expanded := addr.Expand() + expectedSize := 11 // 8000 to 8010 inclusive + + if len(expanded) != expectedSize { + t.Errorf("Expected %d addresses, got %d", expectedSize, len(expanded)) + } + + // Verify each address + for i, expandedAddr := range expanded { + expectedPort := uint(8000 + i) + if expandedAddr.StartPort != expectedPort || expandedAddr.EndPort != expectedPort { + t.Errorf("Address %d: expected port %d, got %d-%d", + i, expectedPort, expandedAddr.StartPort, expandedAddr.EndPort) + } + } +} + +func TestNetworkAddress_IsLoopback_EdgeCases(t *testing.T) { + tests := []struct { + name string + addr NetworkAddress + expected bool + }{ + { + name: "unix socket", + addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"}, + expected: true, // Unix sockets are always considered loopback + }, + { + name: "fd network", + addr: NetworkAddress{Network: "fd", Host: "3"}, + expected: true, // fd networks are always considered loopback + }, + { + name: "localhost", + addr: NetworkAddress{Network: "tcp", Host: "localhost"}, + expected: true, + }, + { + name: "127.0.0.1", + addr: NetworkAddress{Network: "tcp", Host: "127.0.0.1"}, + expected: true, + }, + { + name: "::1", + addr: NetworkAddress{Network: "tcp", Host: "::1"}, + expected: true, + }, + { + name: "127.0.0.2", + addr: NetworkAddress{Network: "tcp", Host: "127.0.0.2"}, + expected: true, // Part of 127.0.0.0/8 loopback range + }, + { + name: "192.168.1.1", + addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"}, + expected: false, // Private but not loopback + }, + { + name: "invalid ip", + addr: NetworkAddress{Network: "tcp", Host: "invalid-ip"}, + expected: false, + }, + { + name: "empty host", + addr: NetworkAddress{Network: "tcp", Host: ""}, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.addr.isLoopback() + if result != test.expected { + t.Errorf("Expected %v, got %v", test.expected, result) + } + }) + } +} + +func TestNetworkAddress_IsWildcard_EdgeCases(t *testing.T) { + tests := []struct { + name string + addr NetworkAddress + expected bool + }{ + { + name: "empty host", + addr: NetworkAddress{Network: "tcp", Host: ""}, + expected: true, + }, + { + name: "ipv4 any", + addr: NetworkAddress{Network: "tcp", Host: "0.0.0.0"}, + expected: true, + }, + { + name: "ipv6 any", + addr: NetworkAddress{Network: "tcp", Host: "::"}, + expected: true, + }, + { + name: "localhost", + addr: NetworkAddress{Network: "tcp", Host: "localhost"}, + expected: false, + }, + { + name: "specific ip", + addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"}, + expected: false, + }, + { + name: "invalid ip", + addr: NetworkAddress{Network: "tcp", Host: "invalid"}, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.addr.isWildcardInterface() + if result != test.expected { + t.Errorf("Expected %v, got %v", test.expected, result) + } + }) + } +} + +func TestSplitNetworkAddress_IPv6_EdgeCases(t *testing.T) { + tests := []struct { + name string + input string + expectNetwork string + expectHost string + expectPort string + expectErr bool + }{ + { + name: "ipv6 with port", + input: "[::1]:8080", + expectHost: "::1", + expectPort: "8080", + }, + { + name: "ipv6 without port", + input: "[::1]", + expectHost: "::1", + }, + { + name: "ipv6 without brackets or port", + input: "::1", + expectHost: "::1", + }, + { + name: "ipv6 loopback", + input: "[::1]:443", + expectHost: "::1", + expectPort: "443", + }, + { + name: "ipv6 any address", + input: "[::]:80", + expectHost: "::", + expectPort: "80", + }, + { + name: "ipv6 with network prefix", + input: "tcp6/[::1]:8080", + expectNetwork: "tcp6", + expectHost: "::1", + expectPort: "8080", + }, + { + name: "malformed ipv6", + input: "[::1:8080", // Missing closing bracket + expectHost: "::1:8080", + }, + { + name: "ipv6 with zone", + input: "[fe80::1%eth0]:8080", + expectHost: "fe80::1%eth0", + expectPort: "8080", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + network, host, port, err := SplitNetworkAddress(test.input) + + if test.expectErr && err == nil { + t.Error("Expected error but got none") + } + if !test.expectErr && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if network != test.expectNetwork { + t.Errorf("Network: expected '%s', got '%s'", test.expectNetwork, network) + } + if host != test.expectHost { + t.Errorf("Host: expected '%s', got '%s'", test.expectHost, host) + } + if port != test.expectPort { + t.Errorf("Port: expected '%s', got '%s'", test.expectPort, port) + } + }) + } +} + +func TestParseNetworkAddress_PortRange_Validation(t *testing.T) { + tests := []struct { + name string + input string + expectErr bool + errMsg string + }{ + { + name: "valid range", + input: "localhost:8080-8090", + expectErr: false, + }, + { + name: "inverted range", + input: "localhost:8090-8080", + expectErr: true, + errMsg: "end port must not be less than start port", + }, + { + name: "too large range", + input: "localhost:0-65535", + expectErr: true, + errMsg: "port range exceeds 65535 ports", + }, + { + name: "invalid start port", + input: "localhost:abc-8080", + expectErr: true, + }, + { + name: "invalid end port", + input: "localhost:8080-xyz", + expectErr: true, + }, + { + name: "port too large", + input: "localhost:99999", + expectErr: true, + }, + { + name: "negative port", + input: "localhost:-80", + expectErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := ParseNetworkAddress(test.input) + + if test.expectErr && err == nil { + t.Error("Expected error but got none") + } + if !test.expectErr && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if test.expectErr && test.errMsg != "" && err != nil { + if !containsString(err.Error(), test.errMsg) { + t.Errorf("Expected error containing '%s', got '%s'", test.errMsg, err.Error()) + } + } + }) + } +} + +func TestNetworkAddress_Listen_ContextCancellation(t *testing.T) { + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 0, // Let OS assign port + EndPort: 0, + } + + // Create context that will be cancelled + ctx, cancel := context.WithCancel(context.Background()) + + // Start listening in a goroutine + listenDone := make(chan error, 1) + go func() { + _, err := addr.Listen(ctx, 0, net.ListenConfig{}) + listenDone <- err + }() + + // Cancel context immediately + cancel() + + // Should get context cancellation error quickly + select { + case err := <-listenDone: + if err == nil { + t.Error("Expected error due to context cancellation") + } + // Accept any error related to context cancellation + // (could be context.Canceled or DNS lookup error due to cancellation) + case <-time.After(time.Second): + t.Error("Listen operation did not respect context cancellation") + } +} + +func TestNetworkAddress_ListenAll_PartialFailure(t *testing.T) { + // Create an address range where some ports might fail to bind + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 0, // OS-assigned port + EndPort: 2, // Try to bind 3 ports starting from OS-assigned + } + + // This test might be flaky depending on available ports, + // but tests the error handling logic + ctx := context.Background() + + listeners, err := addr.ListenAll(ctx, net.ListenConfig{}) + + // Either all succeed or all fail (due to cleanup on partial failure) + if err != nil { + // If there's an error, no listeners should be returned + if len(listeners) != 0 { + t.Errorf("Expected no listeners on error, got %d", len(listeners)) + } + } else { + // If successful, should have listeners for all ports in range + expectedCount := int(addr.PortRangeSize()) + if len(listeners) != expectedCount { + t.Errorf("Expected %d listeners, got %d", expectedCount, len(listeners)) + } + + // Clean up listeners + for _, ln := range listeners { + if closer, ok := ln.(interface{ Close() error }); ok { + closer.Close() + } + } + } +} + +func TestJoinNetworkAddress_SpecialCases(t *testing.T) { + tests := []struct { + name string + network string + host string + port string + expected string + }{ + { + name: "empty everything", + network: "", + host: "", + port: "", + expected: "", + }, + { + name: "network only", + network: "tcp", + host: "", + port: "", + expected: "tcp/", + }, + { + name: "host only", + network: "", + host: "localhost", + port: "", + expected: "localhost", + }, + { + name: "port only", + network: "", + host: "", + port: "8080", + expected: ":8080", + }, + { + name: "unix socket with port (port ignored)", + network: "unix", + host: "/tmp/socket", + port: "8080", + expected: "unix//tmp/socket", + }, + { + name: "fd network with port (port ignored)", + network: "fd", + host: "3", + port: "8080", + expected: "fd/3", + }, + { + name: "ipv6 host with port", + network: "tcp", + host: "::1", + port: "8080", + expected: "tcp/[::1]:8080", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := JoinNetworkAddress(test.network, test.host, test.port) + if result != test.expected { + t.Errorf("Expected '%s', got '%s'", test.expected, result) + } + }) + } +} + +func TestIsUnixNetwork_IsFdNetwork(t *testing.T) { + tests := []struct { + network string + isUnix bool + isFd bool + }{ + {"unix", true, false}, + {"unixgram", true, false}, + {"unixpacket", true, false}, + {"fd", false, true}, + {"fdgram", false, true}, + {"tcp", false, false}, + {"udp", false, false}, + {"", false, false}, + {"unix-like", true, false}, + {"fd-like", false, true}, + } + + for _, test := range tests { + t.Run(test.network, func(t *testing.T) { + if IsUnixNetwork(test.network) != test.isUnix { + t.Errorf("IsUnixNetwork('%s'): expected %v, got %v", + test.network, test.isUnix, IsUnixNetwork(test.network)) + } + if IsFdNetwork(test.network) != test.isFd { + t.Errorf("IsFdNetwork('%s'): expected %v, got %v", + test.network, test.isFd, IsFdNetwork(test.network)) + } + + // Test NetworkAddress methods too + addr := NetworkAddress{Network: test.network} + if addr.IsUnixNetwork() != test.isUnix { + t.Errorf("NetworkAddress.IsUnixNetwork(): expected %v, got %v", + test.isUnix, addr.IsUnixNetwork()) + } + if addr.IsFdNetwork() != test.isFd { + t.Errorf("NetworkAddress.IsFdNetwork(): expected %v, got %v", + test.isFd, addr.IsFdNetwork()) + } + }) + } +} + +func TestRegisterNetwork_Validation(t *testing.T) { + // Save original state + originalNetworkTypes := make(map[string]ListenerFunc) + for k, v := range networkTypes { + originalNetworkTypes[k] = v + } + defer func() { + // Restore original state + networkTypes = originalNetworkTypes + }() + + mockListener := func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error) { + return nil, nil + } + + // Test reserved network types that should panic + reservedTypes := []string{ + "tcp", "tcp4", "tcp6", + "udp", "udp4", "udp6", + "unix", "unixpacket", "unixgram", + "ip:1", "ip4:1", "ip6:1", + "fd", "fdgram", + } + + for _, networkType := range reservedTypes { + t.Run("reserved_"+networkType, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for reserved network type: %s", networkType) + } + }() + RegisterNetwork(networkType, mockListener) + }) + } + + // Test valid registration + t.Run("valid_registration", func(t *testing.T) { + customNetwork := "custom-network" + RegisterNetwork(customNetwork, mockListener) + + if _, exists := networkTypes[customNetwork]; !exists { + t.Error("Custom network should be registered") + } + }) + + // Test duplicate registration should panic + t.Run("duplicate_registration", func(t *testing.T) { + customNetwork := "another-custom" + RegisterNetwork(customNetwork, mockListener) + + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for duplicate registration") + } + }() + RegisterNetwork(customNetwork, mockListener) + }) +} + +func TestListenerUsage_EdgeCases(t *testing.T) { + // Test ListenerUsage function with various inputs + tests := []struct { + name string + network string + addr string + expected int + }{ + { + name: "non-existent listener", + network: "tcp", + addr: "localhost:9999", + expected: 0, + }, + { + name: "empty network and address", + network: "", + addr: "", + expected: 0, + }, + { + name: "unix socket", + network: "unix", + addr: "/tmp/non-existent.sock", + expected: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + usage := ListenerUsage(test.network, test.addr) + if usage != test.expected { + t.Errorf("Expected usage %d, got %d", test.expected, usage) + } + }) + } +} + +func TestNetworkAddress_Port_Formatting(t *testing.T) { + tests := []struct { + name string + addr NetworkAddress + expected string + }{ + { + name: "single port", + addr: NetworkAddress{StartPort: 80, EndPort: 80}, + expected: "80", + }, + { + name: "port range", + addr: NetworkAddress{StartPort: 8080, EndPort: 8090}, + expected: "8080-8090", + }, + { + name: "zero ports", + addr: NetworkAddress{StartPort: 0, EndPort: 0}, + expected: "0", + }, + { + name: "large ports", + addr: NetworkAddress{StartPort: 65534, EndPort: 65535}, + expected: "65534-65535", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.addr.port() + if result != test.expected { + t.Errorf("Expected '%s', got '%s'", test.expected, result) + } + }) + } +} + +func TestNetworkAddress_JoinHostPort_SpecialNetworks(t *testing.T) { + tests := []struct { + name string + addr NetworkAddress + offset uint + expected string + }{ + { + name: "unix socket ignores offset", + addr: NetworkAddress{ + Network: "unix", + Host: "/tmp/socket", + }, + offset: 100, + expected: "/tmp/socket", + }, + { + name: "fd network ignores offset", + addr: NetworkAddress{ + Network: "fd", + Host: "3", + }, + offset: 50, + expected: "3", + }, + { + name: "tcp with offset", + addr: NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 8000, + }, + offset: 10, + expected: "localhost:8010", + }, + { + name: "ipv6 with offset", + addr: NetworkAddress{ + Network: "tcp", + Host: "::1", + StartPort: 8000, + }, + offset: 5, + expected: "[::1]:8005", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := test.addr.JoinHostPort(test.offset) + if result != test.expected { + t.Errorf("Expected '%s', got '%s'", test.expected, result) + } + }) + } +} + +// Helper function for string containment check +func containsString(haystack, needle string) bool { + return len(haystack) >= len(needle) && + (needle == "" || haystack == needle || + strings.Contains(haystack, needle)) +} + +func TestListenerKey_Generation(t *testing.T) { + tests := []struct { + network string + addr string + expected string + }{ + { + network: "tcp", + addr: "localhost:8080", + expected: "tcp/localhost:8080", + }, + { + network: "unix", + addr: "/tmp/socket", + expected: "unix//tmp/socket", + }, + { + network: "", + addr: "localhost:8080", + expected: "/localhost:8080", + }, + { + network: "tcp", + addr: "", + expected: "tcp/", + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s_%s", test.network, test.addr), func(t *testing.T) { + result := listenerKey(test.network, test.addr) + if result != test.expected { + t.Errorf("Expected '%s', got '%s'", test.expected, result) + } + }) + } +} + +func TestNetworkAddress_ConcurrentAccess(t *testing.T) { + // Test that NetworkAddress methods are safe for concurrent read access + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 8080, + EndPort: 8090, + } + + const numGoroutines = 50 + var wg sync.WaitGroup + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Call various methods concurrently + _ = addr.String() + _ = addr.PortRangeSize() + _ = addr.IsUnixNetwork() + _ = addr.IsFdNetwork() + _ = addr.isLoopback() + _ = addr.isWildcardInterface() + _ = addr.port() + _ = addr.JoinHostPort(uint(id % 10)) + _ = addr.At(uint(id % 11)) + + // Expand creates new slice, should be safe + expanded := addr.Expand() + if len(expanded) == 0 { + t.Errorf("Goroutine %d: Expected non-empty expansion", id) + } + }(i) + } + + wg.Wait() +} + +func TestNetworkAddress_IPv6_Zone_Handling(t *testing.T) { + // Test IPv6 addresses with zone identifiers + input := "tcp/[fe80::1%eth0]:8080" + + addr, err := ParseNetworkAddress(input) + if err != nil { + t.Fatalf("Failed to parse IPv6 with zone: %v", err) + } + + if addr.Network != "tcp" { + t.Errorf("Expected network 'tcp', got '%s'", addr.Network) + } + if addr.Host != "fe80::1%eth0" { + t.Errorf("Expected host 'fe80::1%%eth0', got '%s'", addr.Host) + } + if addr.StartPort != 8080 { + t.Errorf("Expected port 8080, got %d", addr.StartPort) + } + + // Test string representation round-trip + str := addr.String() + parsed, err := ParseNetworkAddress(str) + if err != nil { + t.Fatalf("Failed to parse string representation: %v", err) + } + + if parsed.Host != addr.Host { + t.Errorf("Round-trip failed: expected host '%s', got '%s'", addr.Host, parsed.Host) + } +} + +func BenchmarkParseNetworkAddress(b *testing.B) { + inputs := []string{ + "localhost:8080", + "tcp/localhost:8080-8090", + "unix//tmp/socket", + "[::1]:443", + "udp/:53", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + input := inputs[i%len(inputs)] + ParseNetworkAddress(input) + } +} + +func BenchmarkNetworkAddress_String(b *testing.B) { + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 8080, + EndPort: 8090, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr.String() + } +} + +func BenchmarkNetworkAddress_Expand(b *testing.B) { + addr := NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 8000, + EndPort: 8100, // 101 addresses + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + addr.Expand() + } +}