mirror of
https://github.com/caddyserver/caddy.git
synced 2026-05-25 16:22:36 -04:00
Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72875401e3 | |||
| 704394d9d1 | |||
| 6c675e29f8 | |||
| a4004467aa | |||
| 41d8cea9e6 | |||
| ef3158cac7 | |||
| a5ef0600aa | |||
| 9236eacd35 | |||
| 258a928d27 | |||
| e56b31e3ad | |||
| 435e521203 | |||
| 476d75219c | |||
| 719d879f3d | |||
| 5db80034a8 | |||
| a2a7fd6671 | |||
| 201cba5b66 | |||
| f35ea4665d | |||
| c8bc9971b4 | |||
| 3d6f58bf46 | |||
| c29418e299 | |||
| af3d6b3935 | |||
| 656bfc3111 | |||
| 05504942d8 | |||
| 5d50967a0d | |||
| a0f2922157 | |||
| 7a92274e9c | |||
| 6872a66604 | |||
| c2d586c458 | |||
| c6367fb774 | |||
| fc63a3c3f5 | |||
| 93315eafff | |||
| 0b83afa6a5 | |||
| e86b913567 | |||
| b8e72c6a22 | |||
| be4593bd00 | |||
| a6c64276c1 | |||
| 4a9c83b969 |
@@ -132,8 +132,6 @@ jobs:
|
||||
- name: Run tests
|
||||
# id: step_test
|
||||
# continue-on-error: true
|
||||
env:
|
||||
GODEBUG: http2xconnect=1
|
||||
run: |
|
||||
# (go test -v -coverprofile=cover-profile.out -race ./... 2>&1) > test-results/test-result.out
|
||||
go test -v -coverprofile="cover-profile.out" -short -race ./...
|
||||
@@ -193,7 +191,7 @@ jobs:
|
||||
retries=3
|
||||
exit_code=0
|
||||
while ((retries > 0)); do
|
||||
GODEBUG=http2xconnect=1 CGO_ENABLED=0 go test -p 1 -v ./...
|
||||
CGO_ENABLED=0 go test -p 1 -v ./...
|
||||
exit_code=$?
|
||||
if ((exit_code == 0)); then
|
||||
break
|
||||
|
||||
@@ -0,0 +1,377 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAPIError_Error_WithErr(t *testing.T) {
|
||||
underlyingErr := errors.New("underlying error")
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: underlyingErr,
|
||||
Message: "API error message",
|
||||
}
|
||||
|
||||
result := apiErr.Error()
|
||||
expected := "underlying error"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Error_WithoutErr(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: nil,
|
||||
Message: "API error message",
|
||||
}
|
||||
|
||||
result := apiErr.Error()
|
||||
expected := "API error message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Error_BothNil(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: nil,
|
||||
Message: "",
|
||||
}
|
||||
|
||||
result := apiErr.Error()
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected empty string, got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_JSON_Serialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiErr APIError
|
||||
}{
|
||||
{
|
||||
name: "with message only",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "validation failed",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with underlying error only",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Err: errors.New("internal error"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with both message and error",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusConflict,
|
||||
Err: errors.New("underlying"),
|
||||
Message: "conflict detected",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal error",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(test.apiErr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal APIError: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var unmarshaled APIError
|
||||
err = json.Unmarshal(jsonData, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal APIError: %v", err)
|
||||
}
|
||||
|
||||
// Only Message field should survive JSON round-trip
|
||||
// HTTPStatus and Err are marked with json:"-"
|
||||
if unmarshaled.Message != test.apiErr.Message {
|
||||
t.Errorf("Message mismatch: expected '%s', got '%s'",
|
||||
test.apiErr.Message, unmarshaled.Message)
|
||||
}
|
||||
|
||||
// HTTPStatus and Err should be zero values after unmarshal
|
||||
if unmarshaled.HTTPStatus != 0 {
|
||||
t.Errorf("HTTPStatus should be 0 after unmarshal, got %d", unmarshaled.HTTPStatus)
|
||||
}
|
||||
if unmarshaled.Err != nil {
|
||||
t.Errorf("Err should be nil after unmarshal, got %v", unmarshaled.Err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_HTTPStatus_Values(t *testing.T) {
|
||||
// Test common HTTP status codes
|
||||
statusCodes := []int{
|
||||
http.StatusBadRequest,
|
||||
http.StatusUnauthorized,
|
||||
http.StatusForbidden,
|
||||
http.StatusNotFound,
|
||||
http.StatusMethodNotAllowed,
|
||||
http.StatusConflict,
|
||||
http.StatusPreconditionFailed,
|
||||
http.StatusInternalServerError,
|
||||
http.StatusNotImplemented,
|
||||
http.StatusServiceUnavailable,
|
||||
}
|
||||
|
||||
for _, status := range statusCodes {
|
||||
t.Run(fmt.Sprintf("status_%d", status), func(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: status,
|
||||
Message: http.StatusText(status),
|
||||
}
|
||||
|
||||
if apiErr.HTTPStatus != status {
|
||||
t.Errorf("Expected status %d, got %d", status, apiErr.HTTPStatus)
|
||||
}
|
||||
|
||||
// Test that error message is reasonable
|
||||
if apiErr.Message == "" && status >= 400 {
|
||||
t.Errorf("Status %d should have a message", status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_ErrorInterface_Compliance(t *testing.T) {
|
||||
// Verify APIError properly implements error interface
|
||||
var err error = APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "test error",
|
||||
}
|
||||
|
||||
errorMsg := err.Error()
|
||||
if errorMsg != "test error" {
|
||||
t.Errorf("Expected 'test error', got '%s'", errorMsg)
|
||||
}
|
||||
|
||||
// Test with underlying error
|
||||
underlyingErr := errors.New("underlying")
|
||||
err2 := APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Err: underlyingErr,
|
||||
Message: "wrapper",
|
||||
}
|
||||
|
||||
if err2.Error() != "underlying" {
|
||||
t.Errorf("Expected 'underlying', got '%s'", err2.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_JSON_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "empty message",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "unicode message",
|
||||
message: "Error: 🚨 Something went wrong! 你好",
|
||||
},
|
||||
{
|
||||
name: "json characters in message",
|
||||
message: `Error with "quotes" and {brackets}`,
|
||||
},
|
||||
{
|
||||
name: "newlines in message",
|
||||
message: "Line 1\nLine 2\r\nLine 3",
|
||||
},
|
||||
{
|
||||
name: "very long message",
|
||||
message: string(make([]byte, 10000)), // 10KB message
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: test.message,
|
||||
}
|
||||
|
||||
// Should be JSON serializable
|
||||
jsonData, err := json.Marshal(apiErr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal APIError: %v", err)
|
||||
}
|
||||
|
||||
// Should be deserializable
|
||||
var unmarshaled APIError
|
||||
err = json.Unmarshal(jsonData, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal APIError: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Message != test.message {
|
||||
t.Errorf("Message corrupted during JSON round-trip")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Chaining(t *testing.T) {
|
||||
// Test error chaining scenarios
|
||||
rootErr := errors.New("root cause")
|
||||
wrappedErr := fmt.Errorf("wrapped: %w", rootErr)
|
||||
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Err: wrappedErr,
|
||||
Message: "API wrapper",
|
||||
}
|
||||
|
||||
// Error() should return the underlying error message
|
||||
if apiErr.Error() != wrappedErr.Error() {
|
||||
t.Errorf("Expected underlying error message, got '%s'", apiErr.Error())
|
||||
}
|
||||
|
||||
// Should be able to unwrap
|
||||
if !errors.Is(apiErr.Err, rootErr) {
|
||||
t.Error("Should be able to unwrap to root cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_StatusCode_Boundaries(t *testing.T) {
|
||||
// Test edge cases for HTTP status codes
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "negative status",
|
||||
status: -1,
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "zero status",
|
||||
status: 0,
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "valid 1xx",
|
||||
status: http.StatusContinue,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid 2xx",
|
||||
status: http.StatusOK,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid 4xx",
|
||||
status: http.StatusBadRequest,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid 5xx",
|
||||
status: http.StatusInternalServerError,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "too large status",
|
||||
status: 9999,
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := APIError{
|
||||
HTTPStatus: test.status,
|
||||
Message: "test",
|
||||
}
|
||||
|
||||
// The struct allows any int value, but we can test
|
||||
// if it's a valid HTTP status
|
||||
statusText := http.StatusText(test.status)
|
||||
isValidStatus := statusText != ""
|
||||
|
||||
if isValidStatus != test.valid {
|
||||
t.Errorf("Status %d validity: expected %v, got %v",
|
||||
test.status, test.valid, isValidStatus)
|
||||
}
|
||||
|
||||
// Verify the struct holds the status
|
||||
if err.HTTPStatus != test.status {
|
||||
t.Errorf("Status not preserved: expected %d, got %d", test.status, err.HTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAPIError_Error(b *testing.B) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: errors.New("benchmark error"),
|
||||
Message: "benchmark message",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
apiErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAPIError_JSON_Marshal(b *testing.B) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: errors.New("benchmark error"),
|
||||
Message: "benchmark message",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
json.Marshal(apiErr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAPIError_JSON_Unmarshal(b *testing.B) {
|
||||
jsonData := []byte(`{"error": "benchmark message"}`)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var result APIError
|
||||
_ = json.Unmarshal(jsonData, &result)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package caddyfile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestImportGraphAddNode(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
|
||||
g.addNode("a")
|
||||
if !g.exists("a") {
|
||||
t.Error("expected node 'a' to exist after addNode")
|
||||
}
|
||||
|
||||
// Adding again should not error
|
||||
g.addNode("a")
|
||||
if !g.exists("a") {
|
||||
t.Error("expected node 'a' to still exist after duplicate addNode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAddNodes(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
|
||||
g.addNodes([]string{"a", "b", "c"})
|
||||
for _, name := range []string{"a", "b", "c"} {
|
||||
if !g.exists(name) {
|
||||
t.Errorf("expected node %q to exist", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphRemoveNode(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
|
||||
g.addNode("a")
|
||||
g.addNode("b")
|
||||
g.removeNode("a")
|
||||
|
||||
if g.exists("a") {
|
||||
t.Error("expected node 'a' to not exist after removeNode")
|
||||
}
|
||||
if !g.exists("b") {
|
||||
t.Error("expected node 'b' to still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphRemoveNodes(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
|
||||
g.addNodes([]string{"a", "b", "c", "d"})
|
||||
g.removeNodes([]string{"a", "c"})
|
||||
|
||||
if g.exists("a") {
|
||||
t.Error("expected node 'a' to be removed")
|
||||
}
|
||||
if g.exists("c") {
|
||||
t.Error("expected node 'c' to be removed")
|
||||
}
|
||||
if !g.exists("b") {
|
||||
t.Error("expected node 'b' to still exist")
|
||||
}
|
||||
if !g.exists("d") {
|
||||
t.Error("expected node 'd' to still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAddEdge(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b"})
|
||||
|
||||
err := g.addEdge("a", "b")
|
||||
if err != nil {
|
||||
t.Fatalf("addEdge() error = %v", err)
|
||||
}
|
||||
|
||||
if !g.areConnected("a", "b") {
|
||||
t.Error("expected 'a' -> 'b' edge to exist")
|
||||
}
|
||||
if g.areConnected("b", "a") {
|
||||
t.Error("expected no 'b' -> 'a' edge (directed)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAddEdgeNonExistentNode(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNode("a")
|
||||
|
||||
err := g.addEdge("a", "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error when adding edge to nonexistent node")
|
||||
}
|
||||
|
||||
err = g.addEdge("nonexistent", "a")
|
||||
if err == nil {
|
||||
t.Error("expected error when adding edge from nonexistent node")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAddEdgeDuplicate(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b"})
|
||||
|
||||
_ = g.addEdge("a", "b")
|
||||
err := g.addEdge("a", "b")
|
||||
if err != nil {
|
||||
t.Errorf("duplicate addEdge() should not error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphCycleDetectionDirect(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b"})
|
||||
|
||||
_ = g.addEdge("a", "b")
|
||||
|
||||
// Adding b -> a should create a cycle
|
||||
err := g.addEdge("b", "a")
|
||||
if err == nil {
|
||||
t.Error("expected error for cycle: a -> b -> a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphCycleDetectionIndirect(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b", "c"})
|
||||
|
||||
_ = g.addEdge("a", "b")
|
||||
_ = g.addEdge("b", "c")
|
||||
|
||||
// Adding c -> a should create a cycle: a -> b -> c -> a
|
||||
err := g.addEdge("c", "a")
|
||||
if err == nil {
|
||||
t.Error("expected error for indirect cycle: a -> b -> c -> a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphCycleDetectionLongChain(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
nodes := []string{"a", "b", "c", "d", "e"}
|
||||
g.addNodes(nodes)
|
||||
|
||||
_ = g.addEdge("a", "b")
|
||||
_ = g.addEdge("b", "c")
|
||||
_ = g.addEdge("c", "d")
|
||||
_ = g.addEdge("d", "e")
|
||||
|
||||
// Adding e -> a should create a cycle
|
||||
err := g.addEdge("e", "a")
|
||||
if err == nil {
|
||||
t.Error("expected error for long cycle: a -> b -> c -> d -> e -> a")
|
||||
}
|
||||
|
||||
// Adding e -> c should also create a cycle
|
||||
err = g.addEdge("e", "c")
|
||||
if err == nil {
|
||||
t.Error("expected error for cycle: c -> d -> e -> c")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphNoCycleDAG(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b", "c", "d"})
|
||||
|
||||
// Create a diamond DAG: a -> b, a -> c, b -> d, c -> d
|
||||
_ = g.addEdge("a", "b")
|
||||
_ = g.addEdge("a", "c")
|
||||
_ = g.addEdge("b", "d")
|
||||
|
||||
err := g.addEdge("c", "d")
|
||||
if err != nil {
|
||||
t.Errorf("expected no cycle in DAG, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphSelfLoop(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNode("a")
|
||||
|
||||
// BUG: Self-loops are not detected by willCycle(). The function checks if
|
||||
// adding edge from→to would create a cycle by traversing edges from "to"
|
||||
// to see if "from" is reachable. But for a self-loop (from==to), the edge
|
||||
// doesn't exist yet, so the DFS finds nothing and returns false.
|
||||
// A self-importing file would NOT be caught by this cycle detection.
|
||||
err := g.addEdge("a", "a")
|
||||
if err != nil {
|
||||
t.Log("Self-loop was correctly detected (bug may have been fixed)")
|
||||
} else {
|
||||
t.Log("BUG CONFIRMED: addEdge('a', 'a') did not detect self-loop cycle")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphExistsNonExistent(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
if g.exists("nonexistent") {
|
||||
t.Error("expected false for nonexistent node on empty graph")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAreConnectedEmpty(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
if g.areConnected("a", "b") {
|
||||
t.Error("expected false for areConnected on empty graph")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAddEdges(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b", "c", "d"})
|
||||
|
||||
err := g.addEdges("a", []string{"b", "c", "d"})
|
||||
if err != nil {
|
||||
t.Fatalf("addEdges() error = %v", err)
|
||||
}
|
||||
|
||||
if !g.areConnected("a", "b") || !g.areConnected("a", "c") || !g.areConnected("a", "d") {
|
||||
t.Error("expected all edges from 'a' to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphAddEdgesWithCycle(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b", "c"})
|
||||
|
||||
_ = g.addEdge("b", "c")
|
||||
_ = g.addEdge("c", "a")
|
||||
|
||||
// This should fail because a -> b -> c -> a creates a cycle
|
||||
err := g.addEdges("a", []string{"b"})
|
||||
if err == nil {
|
||||
t.Error("expected error when addEdges creates a cycle")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphRemoveNodeEdgeLeakBug(t *testing.T) {
|
||||
// This test documents a known bug: removeNode doesn't clean up edges.
|
||||
// Edges FROM the removed node remain in the adjacency list.
|
||||
g := &importGraph{}
|
||||
g.addNodes([]string{"a", "b", "c"})
|
||||
_ = g.addEdge("a", "b")
|
||||
_ = g.addEdge("b", "c")
|
||||
|
||||
g.removeNode("b")
|
||||
|
||||
// Bug: "b" is removed from nodes, but edges from "b" are still in the adjacency list.
|
||||
// This means the graph is now inconsistent.
|
||||
// The node doesn't exist...
|
||||
if g.exists("b") {
|
||||
t.Error("node 'b' should not exist after removeNode")
|
||||
}
|
||||
|
||||
// ...but edges from "b" may still be present in the edges map (this is a bug).
|
||||
// We test this to document the behavior.
|
||||
if g.edges != nil {
|
||||
if targets, ok := g.edges["b"]; ok && len(targets) > 0 {
|
||||
t.Log("BUG CONFIRMED: removeNode does not clean up outgoing edges. " +
|
||||
"Edges from removed node 'b' still exist in adjacency list.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportGraphWillCycleEmptyGraph(t *testing.T) {
|
||||
g := &importGraph{}
|
||||
// willCycle on empty graph should return false
|
||||
if g.willCycle("a", "b") {
|
||||
t.Error("expected no cycle on empty graph")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
package caddyconfig
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
wantNil bool
|
||||
wantWarnings int
|
||||
nilWarnings bool // pass nil warnings pointer
|
||||
}{
|
||||
{
|
||||
name: "simple string",
|
||||
val: "hello",
|
||||
wantNil: false,
|
||||
wantWarnings: 0,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
val: struct{ Name string }{"test"},
|
||||
wantNil: false,
|
||||
wantWarnings: 0,
|
||||
},
|
||||
{
|
||||
name: "nil value",
|
||||
val: nil,
|
||||
wantNil: false, // json.Marshal(nil) returns "null"
|
||||
wantWarnings: 0,
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
val: map[string]string{"key": "val"},
|
||||
wantNil: false,
|
||||
wantWarnings: 0,
|
||||
},
|
||||
{
|
||||
name: "unmarshalable value produces warning",
|
||||
val: make(chan int),
|
||||
wantNil: true,
|
||||
wantWarnings: 1,
|
||||
},
|
||||
{
|
||||
name: "unmarshalable value with nil warnings pointer",
|
||||
val: make(chan int),
|
||||
wantNil: true,
|
||||
nilWarnings: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var warnings *[]Warning
|
||||
if !tt.nilWarnings {
|
||||
w := []Warning{}
|
||||
warnings = &w
|
||||
}
|
||||
|
||||
result := JSON(tt.val, warnings)
|
||||
|
||||
if tt.wantNil && result != nil {
|
||||
t.Errorf("JSON() = %v, want nil", string(result))
|
||||
}
|
||||
if !tt.wantNil && result == nil {
|
||||
t.Error("JSON() = nil, want non-nil")
|
||||
}
|
||||
if warnings != nil && len(*warnings) != tt.wantWarnings {
|
||||
t.Errorf("JSON() produced %d warnings, want %d", len(*warnings), tt.wantWarnings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONModuleObject(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
fieldName string
|
||||
fieldVal string
|
||||
wantNil bool
|
||||
wantField bool
|
||||
wantWarnings int
|
||||
}{
|
||||
{
|
||||
name: "simple struct",
|
||||
val: struct{ Name string }{"test"},
|
||||
fieldName: "handler",
|
||||
fieldVal: "file_server",
|
||||
wantNil: false,
|
||||
wantField: true,
|
||||
wantWarnings: 0,
|
||||
},
|
||||
{
|
||||
name: "map value",
|
||||
val: map[string]any{"key": "val"},
|
||||
fieldName: "module",
|
||||
fieldVal: "my_module",
|
||||
wantNil: false,
|
||||
wantField: true,
|
||||
wantWarnings: 0,
|
||||
},
|
||||
{
|
||||
name: "non-object type (string) produces warning",
|
||||
val: "not-an-object",
|
||||
fieldName: "handler",
|
||||
fieldVal: "test",
|
||||
wantNil: true,
|
||||
wantField: false,
|
||||
wantWarnings: 1,
|
||||
},
|
||||
{
|
||||
name: "unmarshalable value produces warning",
|
||||
val: make(chan int),
|
||||
fieldName: "handler",
|
||||
fieldVal: "test",
|
||||
wantNil: true,
|
||||
wantField: false,
|
||||
wantWarnings: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
warnings := []Warning{}
|
||||
result := JSONModuleObject(tt.val, tt.fieldName, tt.fieldVal, &warnings)
|
||||
|
||||
if tt.wantNil && result != nil {
|
||||
t.Errorf("JSONModuleObject() = %v, want nil", string(result))
|
||||
}
|
||||
if !tt.wantNil && result == nil {
|
||||
t.Error("JSONModuleObject() = nil, want non-nil")
|
||||
}
|
||||
if len(warnings) != tt.wantWarnings {
|
||||
t.Errorf("JSONModuleObject() produced %d warnings, want %d", len(warnings), tt.wantWarnings)
|
||||
}
|
||||
if tt.wantField && result != nil {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(result, &m); err != nil {
|
||||
t.Fatalf("failed to unmarshal result: %v", err)
|
||||
}
|
||||
if v, ok := m[tt.fieldName]; !ok {
|
||||
t.Errorf("expected field %q in result", tt.fieldName)
|
||||
} else if v != tt.fieldVal {
|
||||
t.Errorf("field %q = %v, want %v", tt.fieldName, v, tt.fieldVal)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONModuleObjectPreservesExistingFields(t *testing.T) {
|
||||
val := struct {
|
||||
Name string `json:"name"`
|
||||
Port int `json:"port"`
|
||||
}{"example", 8080}
|
||||
|
||||
warnings := []Warning{}
|
||||
result := JSONModuleObject(val, "handler", "static", &warnings)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(result, &m); err != nil {
|
||||
t.Fatalf("failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if m["name"] != "example" {
|
||||
t.Errorf("name = %v, want 'example'", m["name"])
|
||||
}
|
||||
if m["port"] != float64(8080) {
|
||||
t.Errorf("port = %v, want 8080", m["port"])
|
||||
}
|
||||
if m["handler"] != "static" {
|
||||
t.Errorf("handler = %v, want 'static'", m["handler"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAdapterNil(t *testing.T) {
|
||||
adapter := GetAdapter("nonexistent_adapter_xyz")
|
||||
if adapter != nil {
|
||||
t.Error("expected nil for unregistered adapter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWarningString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
warning Warning
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "all fields",
|
||||
warning: Warning{File: "Caddyfile", Line: 10, Directive: "reverse_proxy", Message: "upstream not found"},
|
||||
want: "Caddyfile:10 (reverse_proxy): upstream not found",
|
||||
},
|
||||
{
|
||||
name: "no directive",
|
||||
warning: Warning{File: "Caddyfile", Line: 5, Message: "something off"},
|
||||
want: "Caddyfile:5: something off",
|
||||
},
|
||||
{
|
||||
name: "zero line",
|
||||
warning: Warning{File: "config.json", Line: 0, Message: "invalid"},
|
||||
want: "config.json:0: invalid",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.warning.String()
|
||||
if got != tt.want {
|
||||
t.Errorf("Warning.String() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,299 @@
|
||||
package httpcaddyfile
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
)
|
||||
|
||||
func TestShorthandReplacerSimpleReplacements(t *testing.T) {
|
||||
sr := NewShorthandReplacer()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "host",
|
||||
input: "{host}",
|
||||
want: "{http.request.host}",
|
||||
},
|
||||
{
|
||||
name: "hostport",
|
||||
input: "{hostport}",
|
||||
want: "{http.request.hostport}",
|
||||
},
|
||||
{
|
||||
name: "port",
|
||||
input: "{port}",
|
||||
want: "{http.request.port}",
|
||||
},
|
||||
{
|
||||
name: "method",
|
||||
input: "{method}",
|
||||
want: "{http.request.method}",
|
||||
},
|
||||
{
|
||||
name: "uri",
|
||||
input: "{uri}",
|
||||
want: "{http.request.uri}",
|
||||
},
|
||||
{
|
||||
name: "path",
|
||||
input: "{path}",
|
||||
want: "{http.request.uri.path}",
|
||||
},
|
||||
{
|
||||
name: "query",
|
||||
input: "{query}",
|
||||
want: "{http.request.uri.query}",
|
||||
},
|
||||
{
|
||||
name: "scheme",
|
||||
input: "{scheme}",
|
||||
want: "{http.request.scheme}",
|
||||
},
|
||||
{
|
||||
name: "remote_host",
|
||||
input: "{remote_host}",
|
||||
want: "{http.request.remote.host}",
|
||||
},
|
||||
{
|
||||
name: "remote_port",
|
||||
input: "{remote_port}",
|
||||
want: "{http.request.remote.port}",
|
||||
},
|
||||
{
|
||||
name: "uuid",
|
||||
input: "{uuid}",
|
||||
want: "{http.request.uuid}",
|
||||
},
|
||||
{
|
||||
name: "tls_cipher",
|
||||
input: "{tls_cipher}",
|
||||
want: "{http.request.tls.cipher_suite}",
|
||||
},
|
||||
{
|
||||
name: "tls_version",
|
||||
input: "{tls_version}",
|
||||
want: "{http.request.tls.version}",
|
||||
},
|
||||
{
|
||||
name: "client_ip",
|
||||
input: "{client_ip}",
|
||||
want: "{http.vars.client_ip}",
|
||||
},
|
||||
{
|
||||
name: "upstream_hostport",
|
||||
input: "{upstream_hostport}",
|
||||
want: "{http.reverse_proxy.upstream.hostport}",
|
||||
},
|
||||
{
|
||||
name: "dir",
|
||||
input: "{dir}",
|
||||
want: "{http.request.uri.path.dir}",
|
||||
},
|
||||
{
|
||||
name: "file",
|
||||
input: "{file}",
|
||||
want: "{http.request.uri.path.file}",
|
||||
},
|
||||
{
|
||||
name: "orig_method",
|
||||
input: "{orig_method}",
|
||||
want: "{http.request.orig_method}",
|
||||
},
|
||||
{
|
||||
name: "orig_uri",
|
||||
input: "{orig_uri}",
|
||||
want: "{http.request.orig_uri}",
|
||||
},
|
||||
{
|
||||
name: "orig_path",
|
||||
input: "{orig_path}",
|
||||
want: "{http.request.orig_uri.path}",
|
||||
},
|
||||
{
|
||||
name: "no matching placeholder",
|
||||
input: "{unknown}",
|
||||
want: "{unknown}",
|
||||
},
|
||||
{
|
||||
name: "not a placeholder",
|
||||
input: "plain text",
|
||||
want: "plain text",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "multiple placeholders in one string",
|
||||
input: "{host}:{port}",
|
||||
want: "{http.request.host}:{http.request.port}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
segment := caddyfile.Segment{{Text: tt.input}}
|
||||
sr.ApplyToSegment(&segment)
|
||||
got := segment[0].Text
|
||||
if got != tt.want {
|
||||
t.Errorf("ApplyToSegment(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShorthandReplacerComplexReplacements(t *testing.T) {
|
||||
sr := NewShorthandReplacer()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "header placeholder",
|
||||
input: "{header.X-Forwarded-For}",
|
||||
want: "{http.request.header.X-Forwarded-For}",
|
||||
},
|
||||
{
|
||||
name: "cookie placeholder",
|
||||
input: "{cookie.session_id}",
|
||||
want: "{http.request.cookie.session_id}",
|
||||
},
|
||||
{
|
||||
name: "labels placeholder",
|
||||
input: "{labels.0}",
|
||||
want: "{http.request.host.labels.0}",
|
||||
},
|
||||
{
|
||||
name: "path segment placeholder",
|
||||
input: "{path.0}",
|
||||
want: "{http.request.uri.path.0}",
|
||||
},
|
||||
{
|
||||
name: "query placeholder",
|
||||
input: "{query.page}",
|
||||
want: "{http.request.uri.query.page}",
|
||||
},
|
||||
{
|
||||
name: "re placeholder with dots",
|
||||
input: "{re.name.group}",
|
||||
want: "{http.regexp.name.group}",
|
||||
},
|
||||
{
|
||||
name: "vars placeholder",
|
||||
input: "{vars.my_var}",
|
||||
want: "{http.vars.my_var}",
|
||||
},
|
||||
{
|
||||
name: "rp placeholder",
|
||||
input: "{rp.upstream.address}",
|
||||
want: "{http.reverse_proxy.upstream.address}",
|
||||
},
|
||||
{
|
||||
name: "resp placeholder",
|
||||
input: "{resp.status_code}",
|
||||
want: "{http.intercept.status_code}",
|
||||
},
|
||||
{
|
||||
name: "err placeholder",
|
||||
input: "{err.status_code}",
|
||||
want: "{http.error.status_code}",
|
||||
},
|
||||
{
|
||||
name: "file_match placeholder",
|
||||
input: "{file_match.relative}",
|
||||
want: "{http.matchers.file.relative}",
|
||||
},
|
||||
{
|
||||
name: "header with hyphen",
|
||||
input: "{header.Content-Type}",
|
||||
want: "{http.request.header.Content-Type}",
|
||||
},
|
||||
{
|
||||
name: "header with underscore",
|
||||
input: "{header.X_Custom_Header}",
|
||||
want: "{http.request.header.X_Custom_Header}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
segment := caddyfile.Segment{{Text: tt.input}}
|
||||
sr.ApplyToSegment(&segment)
|
||||
got := segment[0].Text
|
||||
if got != tt.want {
|
||||
t.Errorf("ApplyToSegment(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShorthandReplacerApplyToNilSegment(t *testing.T) {
|
||||
sr := NewShorthandReplacer()
|
||||
// Should not panic
|
||||
sr.ApplyToSegment(nil)
|
||||
}
|
||||
|
||||
func TestShorthandReplacerMultipleTokens(t *testing.T) {
|
||||
sr := NewShorthandReplacer()
|
||||
|
||||
segment := caddyfile.Segment{
|
||||
{Text: "{host}"},
|
||||
{Text: "{path}"},
|
||||
{Text: "{header.X-Test}"},
|
||||
{Text: "plain"},
|
||||
}
|
||||
|
||||
sr.ApplyToSegment(&segment)
|
||||
|
||||
expected := []string{
|
||||
"{http.request.host}",
|
||||
"{http.request.uri.path}",
|
||||
"{http.request.header.X-Test}",
|
||||
"plain",
|
||||
}
|
||||
|
||||
for i, want := range expected {
|
||||
if segment[i].Text != want {
|
||||
t.Errorf("token %d: got %q, want %q", i, segment[i].Text, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShorthandReplacerEmptySegment(t *testing.T) {
|
||||
sr := NewShorthandReplacer()
|
||||
segment := caddyfile.Segment{}
|
||||
sr.ApplyToSegment(&segment) // should not panic
|
||||
}
|
||||
|
||||
func TestShorthandReplacerEscapedPlaceholders(t *testing.T) {
|
||||
sr := NewShorthandReplacer()
|
||||
|
||||
// Percent-escaped path placeholder
|
||||
segment := caddyfile.Segment{{Text: "{%path}"}}
|
||||
sr.ApplyToSegment(&segment)
|
||||
if segment[0].Text != "{http.request.uri.path_escaped}" {
|
||||
t.Errorf("got %q, want {http.request.uri.path_escaped}", segment[0].Text)
|
||||
}
|
||||
|
||||
// Percent-escaped query placeholder
|
||||
segment = caddyfile.Segment{{Text: "{%query}"}}
|
||||
sr.ApplyToSegment(&segment)
|
||||
if segment[0].Text != "{http.request.uri.query_escaped}" {
|
||||
t.Errorf("got %q, want {http.request.uri.query_escaped}", segment[0].Text)
|
||||
}
|
||||
|
||||
// Prefixed query
|
||||
segment = caddyfile.Segment{{Text: "{?query}"}}
|
||||
sr.ApplyToSegment(&segment)
|
||||
if segment[0].Text != "{http.request.uri.prefixed_query}" {
|
||||
t.Errorf("got %q, want {http.request.uri.prefixed_query}", segment[0].Text)
|
||||
}
|
||||
}
|
||||
@@ -1,328 +0,0 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
var errExtendedConnectUnsupportedByPeer = errors.New("peer did not advertise RFC 8441 extended CONNECT support")
|
||||
|
||||
func TestReverseProxyExtendedConnectOverH2(t *testing.T) {
|
||||
tester := caddytest.NewTester(t)
|
||||
backend := newWebsocketUpgradeEchoBackend(t)
|
||||
defer backend.Close()
|
||||
|
||||
tester.InitServer(fmt.Sprintf(`
|
||||
{
|
||||
admin localhost:2999
|
||||
http_port 9080
|
||||
https_port 9443
|
||||
grace_period 1ns
|
||||
skip_install_trust
|
||||
servers :9443 {
|
||||
protocols h2
|
||||
}
|
||||
}
|
||||
|
||||
https://localhost:9443 {
|
||||
reverse_proxy %s
|
||||
}
|
||||
`, backend.addr), "caddyfile")
|
||||
|
||||
const payload = "extended-connect-echo\n"
|
||||
if err := assertExtendedConnectH2Echo("localhost:9443", payload); err != nil {
|
||||
if errors.Is(err, errExtendedConnectUnsupportedByPeer) {
|
||||
t.Skipf("skipping extended CONNECT integration test: %v", err)
|
||||
}
|
||||
t.Fatalf("extended connect h2 echo failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertExtendedConnectH2Echo(addr, payload string) error {
|
||||
conn, err := tlsDialH2(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing h2 tls: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return fmt.Errorf("setting deadline: %w", err)
|
||||
}
|
||||
|
||||
fr := http2.NewFramer(conn, conn)
|
||||
|
||||
if _, err := conn.Write([]byte(http2.ClientPreface)); err != nil {
|
||||
return fmt.Errorf("writing client preface: %w", err)
|
||||
}
|
||||
if err := fr.WriteSettings(http2.Setting{ID: http2.SettingEnableConnectProtocol, Val: 1}); err != nil {
|
||||
return fmt.Errorf("writing client settings: %w", err)
|
||||
}
|
||||
|
||||
supported, err := waitForServerSettings(fr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !supported {
|
||||
return errExtendedConnectUnsupportedByPeer
|
||||
}
|
||||
if err := waitForSettingsAck(fr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeExtendedConnectHeaders(fr, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := readResponseStatus(fr, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if status != "200" {
|
||||
return fmt.Errorf("unexpected extended connect status: got=%s want=200", status)
|
||||
}
|
||||
|
||||
if err := fr.WriteData(1, false, []byte(payload)); err != nil {
|
||||
return fmt.Errorf("writing stream data: %w", err)
|
||||
}
|
||||
|
||||
echo, err := readStreamData(fr, 1, len(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if echo != payload {
|
||||
return fmt.Errorf("unexpected echoed payload: got=%q want=%q", echo, payload)
|
||||
}
|
||||
|
||||
_ = fr.WriteRSTStream(1, http2.ErrCodeNo)
|
||||
return nil
|
||||
}
|
||||
|
||||
func tlsDialH2(addr string) (net.Conn, error) {
|
||||
var lastErr error
|
||||
for i := 0; i < 30; i++ {
|
||||
dialer := &net.Dialer{Timeout: 2 * time.Second}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
|
||||
ServerName: "localhost",
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{"h2"},
|
||||
})
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func waitForServerSettings(fr *http2.Framer) (bool, error) {
|
||||
for {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("reading frame before connect: %w", err)
|
||||
}
|
||||
settings, ok := frame.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if settings.IsAck() {
|
||||
continue
|
||||
}
|
||||
|
||||
supported := false
|
||||
if err := settings.ForeachSetting(func(s http2.Setting) error {
|
||||
if s.ID == http2.SettingEnableConnectProtocol && s.Val == 1 {
|
||||
supported = true
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return false, fmt.Errorf("reading server settings: %w", err)
|
||||
}
|
||||
|
||||
if err := fr.WriteSettingsAck(); err != nil {
|
||||
return false, fmt.Errorf("writing settings ack: %w", err)
|
||||
}
|
||||
return supported, nil
|
||||
}
|
||||
}
|
||||
|
||||
func waitForSettingsAck(fr *http2.Framer) error {
|
||||
for {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading settings ack: %w", err)
|
||||
}
|
||||
settings, ok := frame.(*http2.SettingsFrame)
|
||||
if ok && settings.IsAck() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func writeExtendedConnectHeaders(fr *http2.Framer, addr string) error {
|
||||
var hb bytes.Buffer
|
||||
enc := hpack.NewEncoder(&hb)
|
||||
for _, hf := range []hpack.HeaderField{
|
||||
{Name: ":method", Value: "CONNECT"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":authority", Value: addr},
|
||||
{Name: ":path", Value: "/upgrade"},
|
||||
{Name: ":protocol", Value: "websocket"},
|
||||
} {
|
||||
if err := enc.WriteField(hf); err != nil {
|
||||
return fmt.Errorf("encoding request headers: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fr.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 1,
|
||||
BlockFragment: hb.Bytes(),
|
||||
EndHeaders: true,
|
||||
EndStream: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("writing extended connect headers: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readResponseStatus(fr *http2.Framer, streamID uint32) (string, error) {
|
||||
var block bytes.Buffer
|
||||
|
||||
for {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading response headers: %w", err)
|
||||
}
|
||||
if rst, ok := frame.(*http2.RSTStreamFrame); ok && rst.StreamID == streamID {
|
||||
return "", fmt.Errorf("stream reset before response headers: %s", rst.ErrCode)
|
||||
}
|
||||
|
||||
h, ok := frame.(*http2.HeadersFrame)
|
||||
if !ok || h.StreamID != streamID {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := block.Write(h.HeaderBlockFragment()); err != nil {
|
||||
return "", fmt.Errorf("buffering response header fragment: %w", err)
|
||||
}
|
||||
for !h.HeadersEnded() {
|
||||
next, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading continuation frame: %w", err)
|
||||
}
|
||||
c, ok := next.(*http2.ContinuationFrame)
|
||||
if !ok || c.StreamID != streamID {
|
||||
continue
|
||||
}
|
||||
if _, err := block.Write(c.HeaderBlockFragment()); err != nil {
|
||||
return "", fmt.Errorf("buffering continuation fragment: %w", err)
|
||||
}
|
||||
if c.HeadersEnded() {
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var status string
|
||||
dec := hpack.NewDecoder(4096, func(f hpack.HeaderField) {
|
||||
if f.Name == ":status" {
|
||||
status = f.Value
|
||||
}
|
||||
})
|
||||
if _, err := dec.Write(block.Bytes()); err != nil {
|
||||
return "", fmt.Errorf("decoding response header block: %w", err)
|
||||
}
|
||||
if status == "" {
|
||||
return "", fmt.Errorf("missing :status in response headers")
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func readStreamData(fr *http2.Framer, streamID uint32, n int) (string, error) {
|
||||
buf := make([]byte, 0, n)
|
||||
for len(buf) < n {
|
||||
frame, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading stream data: %w", err)
|
||||
}
|
||||
d, ok := frame.(*http2.DataFrame)
|
||||
if !ok || d.StreamID != streamID {
|
||||
continue
|
||||
}
|
||||
buf = append(buf, d.Data()...)
|
||||
}
|
||||
return string(buf[:n]), nil
|
||||
}
|
||||
|
||||
type websocketUpgradeEchoBackend struct {
|
||||
addr string
|
||||
ln net.Listener
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func newWebsocketUpgradeEchoBackend(t *testing.T) *websocketUpgradeEchoBackend {
|
||||
t.Helper()
|
||||
|
||||
backend := &websocketUpgradeEchoBackend{}
|
||||
backend.server = &http.Server{
|
||||
Handler: http.HandlerFunc(backend.serveHTTP),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listening for websocket backend: %v", err)
|
||||
}
|
||||
backend.ln = ln
|
||||
backend.addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
_ = backend.server.Serve(ln)
|
||||
}()
|
||||
|
||||
return backend
|
||||
}
|
||||
|
||||
func (b *websocketUpgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")
|
||||
_ = rw.Flush()
|
||||
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *websocketUpgradeEchoBackend) Close() {
|
||||
_ = b.server.Close()
|
||||
_ = b.ln.Close()
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
func TestReverseProxyUpgradeWithEncode(t *testing.T) {
|
||||
tester := caddytest.NewTester(t)
|
||||
backend := newUpgradeEchoBackend(t)
|
||||
defer backend.Close()
|
||||
|
||||
tester.InitServer(fmt.Sprintf(`
|
||||
{
|
||||
admin localhost:2999
|
||||
http_port 9080
|
||||
https_port 9443
|
||||
grace_period 1ns
|
||||
skip_install_trust
|
||||
}
|
||||
|
||||
localhost:9080 {
|
||||
route {
|
||||
encode gzip
|
||||
reverse_proxy %s
|
||||
}
|
||||
}
|
||||
`, backend.addr), "caddyfile")
|
||||
|
||||
client := newUpgradedStreamClientWithHeaders(t, map[string]string{
|
||||
"Accept-Encoding": "gzip",
|
||||
})
|
||||
defer client.Close()
|
||||
|
||||
if err := client.echo("encode-upgrade\n"); err != nil {
|
||||
t.Fatalf("upgraded stream echo through encode failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyUpgradeWithInterceptHandleResponse(t *testing.T) {
|
||||
tester := caddytest.NewTester(t)
|
||||
backend := newUpgradeEchoBackend(t)
|
||||
defer backend.Close()
|
||||
|
||||
tester.InitServer(fmt.Sprintf(`
|
||||
{
|
||||
admin localhost:2999
|
||||
http_port 9080
|
||||
https_port 9443
|
||||
grace_period 1ns
|
||||
skip_install_trust
|
||||
}
|
||||
|
||||
localhost:9080 {
|
||||
route {
|
||||
intercept {
|
||||
@upgrade status 101
|
||||
handle_response @upgrade {
|
||||
respond "should-not-run"
|
||||
}
|
||||
}
|
||||
reverse_proxy %s
|
||||
}
|
||||
}
|
||||
`, backend.addr), "caddyfile")
|
||||
|
||||
client := newUpgradedStreamClientWithHeaders(t, nil)
|
||||
defer client.Close()
|
||||
|
||||
if err := client.echo("intercept-upgrade\n"); err != nil {
|
||||
t.Fatalf("upgraded stream echo through intercept failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newUpgradedStreamClientWithHeaders(t *testing.T, extraHeaders map[string]string) *upgradedStreamClient {
|
||||
t.Helper()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dialing caddy: %v", err)
|
||||
}
|
||||
|
||||
requestLines := []string{
|
||||
"GET /upgrade HTTP/1.1",
|
||||
"Host: localhost:9080",
|
||||
"Connection: Upgrade",
|
||||
"Upgrade: stress-stream",
|
||||
}
|
||||
for k, v := range extraHeaders {
|
||||
requestLines = append(requestLines, k+": "+v)
|
||||
}
|
||||
requestLines = append(requestLines, "", "")
|
||||
|
||||
if _, err := io.WriteString(conn, strings.Join(requestLines, "\r\n")); err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("writing upgrade request: %v", err)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
tproto := textproto.NewReader(reader)
|
||||
statusLine, err := tproto.ReadLine()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "101") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade status: %s", statusLine)
|
||||
}
|
||||
|
||||
headers, err := tproto.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade headers: %v", err)
|
||||
}
|
||||
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade response headers: %v", headers)
|
||||
}
|
||||
|
||||
return &upgradedStreamClient{conn: conn, reader: reader}
|
||||
}
|
||||
@@ -1,504 +0,0 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddytest"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultStressStreamCount = 1
|
||||
defaultStressReloadCount = 1
|
||||
defaultStressCloseDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
func TestReverseProxyReloadStressUpgradedStreamsHeapProfiles(t *testing.T) {
|
||||
tester := caddytest.NewTester(t).WithDefaultOverrides(caddytest.Config{
|
||||
LoadRequestTimeout: 30 * time.Second,
|
||||
TestRequestTimeout: 30 * time.Second,
|
||||
})
|
||||
|
||||
backend := newUpgradeEchoBackend(t)
|
||||
defer backend.Close()
|
||||
|
||||
// Three scenarios, each sequential so they don't share Caddy state:
|
||||
//
|
||||
// legacy – no delay, close on reload immediately (old default)
|
||||
// close_delay – stream_close_delay, the old "keep-alive workaround"
|
||||
// detached – stream_detached, the new explicit detached flag
|
||||
//
|
||||
// Reloads are spread across time and interleaved with echo-checks so
|
||||
// stream health is exercised at each reload boundary, not only at the end.
|
||||
legacy := runReloadStress(t, tester, backend.addr, "legacy", false, 0)
|
||||
closeDelay := runReloadStress(t, tester, backend.addr, "close_delay", false, stressCloseDelay(t))
|
||||
detached := runReloadStress(t, tester, backend.addr, "detached", true, 0)
|
||||
|
||||
if legacy.aliveAfterReloads != 0 {
|
||||
t.Fatalf("legacy mode left %d upgraded streams alive after reloads", legacy.aliveAfterReloads)
|
||||
}
|
||||
if closeDelay.aliveBeforeDelayExpiry == 0 {
|
||||
t.Fatalf("close_delay mode: all streams closed before delay expired (expected them alive)")
|
||||
}
|
||||
if closeDelay.aliveAfterReloads != 0 {
|
||||
t.Fatalf("close_delay mode left %d upgraded streams alive after delay expiry", closeDelay.aliveAfterReloads)
|
||||
}
|
||||
if detached.aliveAfterReloads != detached.streamCount {
|
||||
t.Fatalf("detached mode kept %d/%d upgraded streams alive after reloads", detached.aliveAfterReloads, detached.streamCount)
|
||||
}
|
||||
|
||||
t.Logf("legacy heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
|
||||
formatBytes(legacy.beforeReload.HeapInuse),
|
||||
formatBytes(legacy.midReload.HeapInuse),
|
||||
formatBytes(legacy.afterReload.HeapInuse),
|
||||
formatBytesDiff(legacy.beforeReload.HeapInuse, legacy.afterReload.HeapInuse),
|
||||
legacy.beforeReload.HeapObjects, legacy.afterReload.HeapObjects,
|
||||
legacy.beforeReload.handlerFrames, legacy.afterReload.handlerFrames,
|
||||
)
|
||||
t.Logf("close_delay heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
|
||||
formatBytes(closeDelay.beforeReload.HeapInuse),
|
||||
formatBytes(closeDelay.midReload.HeapInuse),
|
||||
formatBytes(closeDelay.afterReload.HeapInuse),
|
||||
formatBytesDiff(closeDelay.beforeReload.HeapInuse, closeDelay.afterReload.HeapInuse),
|
||||
closeDelay.beforeReload.HeapObjects, closeDelay.afterReload.HeapObjects,
|
||||
closeDelay.beforeReload.handlerFrames, closeDelay.afterReload.handlerFrames,
|
||||
)
|
||||
t.Logf("detached heap: before=%s mid=%s after=%s delta(before→after)=%s objects(before=%d after=%d) handler_frames(before=%d after=%d)",
|
||||
formatBytes(detached.beforeReload.HeapInuse),
|
||||
formatBytes(detached.midReload.HeapInuse),
|
||||
formatBytes(detached.afterReload.HeapInuse),
|
||||
formatBytesDiff(detached.beforeReload.HeapInuse, detached.afterReload.HeapInuse),
|
||||
detached.beforeReload.HeapObjects, detached.afterReload.HeapObjects,
|
||||
detached.beforeReload.handlerFrames, detached.afterReload.handlerFrames,
|
||||
)
|
||||
}
|
||||
|
||||
type stressRunResult struct {
|
||||
streamCount int
|
||||
aliveAfterReloads int
|
||||
aliveBeforeDelayExpiry int // only meaningful for close_delay mode
|
||||
beforeReload heapSnapshot
|
||||
midReload heapSnapshot // after all reloads, before delay expiry clean-up
|
||||
afterReload heapSnapshot // after all streams have been fully cleaned up
|
||||
}
|
||||
|
||||
type heapSnapshot struct {
|
||||
HeapInuse uint64
|
||||
HeapObjects uint64
|
||||
handlerFrames int
|
||||
profileBytes int
|
||||
}
|
||||
|
||||
// runReloadStress opens streamCount upgraded streams, then performs reloadCount
|
||||
// config reloads spread over time. An echo check is performed every 6 reloads so
|
||||
// stream health is exercised at each reload boundary rather than only at the end.
|
||||
// closeDelay mirrors the stream_close_delay config option; pass 0 to disable.
|
||||
func runReloadStress(t *testing.T, tester *caddytest.Tester, backendAddr, mode string, detach bool, closeDelay time.Duration) stressRunResult {
|
||||
t.Helper()
|
||||
|
||||
const echoEvery = 6 // perform an echo check every N reloads
|
||||
|
||||
streamCount := envIntOrDefault(t, "CADDY_STRESS_STREAM_COUNT", defaultStressStreamCount)
|
||||
reloadCount := envIntOrDefault(t, "CADDY_STRESS_RELOAD_COUNT", defaultStressReloadCount)
|
||||
|
||||
tester.InitServer(reloadStressConfig(backendAddr, detach, closeDelay, 0), "caddyfile")
|
||||
|
||||
clients := make([]*upgradedStreamClient, 0, streamCount)
|
||||
for i := 0; i < streamCount; i++ {
|
||||
client := newUpgradedStreamClient(t)
|
||||
clients = append(clients, client)
|
||||
if err := client.echo(fmt.Sprintf("%s-warmup-%02d\n", mode, i)); err != nil {
|
||||
closeClients(clients)
|
||||
t.Fatalf("warmup echo failed in %s mode: %v", mode, err)
|
||||
}
|
||||
}
|
||||
defer closeClients(clients)
|
||||
|
||||
before := captureHeapSnapshot(t)
|
||||
|
||||
// Reloads are spread across time; between batches of echoEvery reloads we
|
||||
// pause briefly and measure stream health so the snapshot reflects real-world
|
||||
// reload cadence rather than a tight loop.
|
||||
for i := 1; i <= reloadCount; i++ {
|
||||
loadCaddyfileConfig(t, reloadStressConfig(backendAddr, detach, closeDelay, i))
|
||||
|
||||
// Small pause after each reload to let connection teardown propagate.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if i%echoEvery == 0 {
|
||||
alive := countAliveStreams(clients)
|
||||
t.Logf("%s mode: %d/%d streams alive after reload %d", mode, alive, streamCount, i)
|
||||
|
||||
// In detached mode, every stream must survive every reload (upstream unchanged).
|
||||
if detach {
|
||||
for j, client := range clients {
|
||||
if err := client.echo(fmt.Sprintf("%s-mid-%02d-%02d\n", mode, i, j)); err != nil {
|
||||
t.Fatalf("detached mode stream %d died at reload %d: %v", j, i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mid snapshot: after all reloads but before any close_delay timer has fired
|
||||
// (the delay is long enough to still be running at this point).
|
||||
mid := captureHeapSnapshot(t)
|
||||
|
||||
// For legacy mode: the reloads close streams immediately; wait for that to complete.
|
||||
// For close_delay mode: streams are still alive here; wait for the delay to fire.
|
||||
// For detached mode: streams survive indefinitely; no wait needed.
|
||||
var aliveBeforeDelayExpiry int
|
||||
aliveAfterReloads := countAliveStreams(clients)
|
||||
switch {
|
||||
case detach:
|
||||
// nothing to wait for
|
||||
case closeDelay > 0:
|
||||
// streams should still be alive at this point (delay hasn't expired)
|
||||
aliveBeforeDelayExpiry = aliveAfterReloads
|
||||
t.Logf("%s mode: %d/%d streams alive before close_delay expires; waiting %v for cleanup",
|
||||
mode, aliveBeforeDelayExpiry, streamCount, closeDelay)
|
||||
time.Sleep(closeDelay + 200*time.Millisecond)
|
||||
aliveAfterReloads = countAliveStreams(clients)
|
||||
default:
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for aliveAfterReloads > 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
aliveAfterReloads = countAliveStreams(clients)
|
||||
}
|
||||
}
|
||||
|
||||
after := captureHeapSnapshot(t)
|
||||
t.Logf("%s mode heap profile size: before=%dB mid=%dB after=%dB objects(before=%d mid=%d after=%d)",
|
||||
mode,
|
||||
before.profileBytes, mid.profileBytes, after.profileBytes,
|
||||
before.HeapObjects, mid.HeapObjects, after.HeapObjects,
|
||||
)
|
||||
|
||||
return stressRunResult{
|
||||
streamCount: streamCount,
|
||||
aliveAfterReloads: aliveAfterReloads,
|
||||
aliveBeforeDelayExpiry: aliveBeforeDelayExpiry,
|
||||
beforeReload: before,
|
||||
midReload: mid,
|
||||
afterReload: after,
|
||||
}
|
||||
}
|
||||
|
||||
func envIntOrDefault(t *testing.T, key string, def int) int {
|
||||
t.Helper()
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
return def
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v <= 0 {
|
||||
t.Fatalf("invalid %s=%q: must be a positive integer", key, raw)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func stressCloseDelay(t *testing.T) time.Duration {
|
||||
t.Helper()
|
||||
|
||||
const key = "CADDY_STRESS_CLOSE_DELAY"
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
return defaultStressCloseDelay
|
||||
}
|
||||
v, err := time.ParseDuration(raw)
|
||||
if err != nil || v <= 0 {
|
||||
t.Fatalf("invalid %s=%q: must be a positive duration", key, raw)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func loadCaddyfileConfig(t *testing.T, rawConfig string) {
|
||||
t.Helper()
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
req, err := http.NewRequest(http.MethodPost, "http://localhost:2999/load", strings.NewReader(rawConfig))
|
||||
if err != nil {
|
||||
t.Fatalf("creating load request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "text/caddyfile")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("loading config: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading load response: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("loading config failed: status=%d body=%s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func reloadStressConfig(backendAddr string, detach bool, closeDelay time.Duration, revision int) string {
|
||||
var directives string
|
||||
if detach {
|
||||
directives += "\n\t\tstream_detached"
|
||||
}
|
||||
if closeDelay > 0 {
|
||||
directives += fmt.Sprintf("\n\t\tstream_close_delay %s", closeDelay)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
{
|
||||
admin localhost:2999
|
||||
http_port 9080
|
||||
https_port 9443
|
||||
grace_period 1ns
|
||||
skip_install_trust
|
||||
}
|
||||
|
||||
localhost:9080 {
|
||||
reverse_proxy %s {
|
||||
header_up X-Reload-Revision %d%s
|
||||
}
|
||||
}
|
||||
`, backendAddr, revision, directives)
|
||||
}
|
||||
|
||||
func captureHeapSnapshot(t *testing.T) heapSnapshot {
|
||||
t.Helper()
|
||||
|
||||
runtime.GC()
|
||||
debug.FreeOSMemory()
|
||||
|
||||
var mem runtime.MemStats
|
||||
runtime.ReadMemStats(&mem)
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := pprof.Lookup("heap").WriteTo(&buf, 1); err != nil {
|
||||
t.Fatalf("capturing heap profile: %v", err)
|
||||
}
|
||||
profile := buf.String()
|
||||
|
||||
return heapSnapshot{
|
||||
HeapInuse: mem.HeapInuse,
|
||||
HeapObjects: mem.HeapObjects,
|
||||
handlerFrames: strings.Count(profile, "modules/caddyhttp/reverseproxy.(*Handler)"),
|
||||
profileBytes: buf.Len(),
|
||||
}
|
||||
}
|
||||
|
||||
func countAliveStreams(clients []*upgradedStreamClient) int {
|
||||
alive := 0
|
||||
for index, client := range clients {
|
||||
if err := client.echo(fmt.Sprintf("alive-check-%02d\n", index)); err == nil {
|
||||
alive++
|
||||
}
|
||||
}
|
||||
return alive
|
||||
}
|
||||
|
||||
func closeClients(clients []*upgradedStreamClient) {
|
||||
for _, client := range clients {
|
||||
if client != nil {
|
||||
_ = client.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatBytes(value uint64) string {
|
||||
const unit = 1024
|
||||
if value < unit {
|
||||
return fmt.Sprintf("%d B", value)
|
||||
}
|
||||
div, exp := uint64(unit), 0
|
||||
for n := value / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(value)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func formatBytesDiff(before, after uint64) string {
|
||||
if after >= before {
|
||||
return "+" + formatBytes(after-before)
|
||||
}
|
||||
return "-" + formatBytes(before-after)
|
||||
}
|
||||
|
||||
type upgradedStreamClient struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newUpgradedStreamClient(t *testing.T) *upgradedStreamClient {
|
||||
t.Helper()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", "127.0.0.1:9080", 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dialing caddy: %v", err)
|
||||
}
|
||||
|
||||
request := strings.Join([]string{
|
||||
"GET /upgrade HTTP/1.1",
|
||||
"Host: localhost:9080",
|
||||
"Connection: Upgrade",
|
||||
"Upgrade: stress-stream",
|
||||
"",
|
||||
"",
|
||||
}, "\r\n")
|
||||
if _, err := io.WriteString(conn, request); err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("writing upgrade request: %v", err)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
tproto := textproto.NewReader(reader)
|
||||
statusLine, err := tproto.ReadLine()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade status line: %v", err)
|
||||
}
|
||||
if !strings.Contains(statusLine, "101") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade status: %s", statusLine)
|
||||
}
|
||||
|
||||
headers, err := tproto.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("reading upgrade headers: %v", err)
|
||||
}
|
||||
if !strings.EqualFold(headers.Get("Connection"), "Upgrade") {
|
||||
_ = conn.Close()
|
||||
t.Fatalf("unexpected upgrade response headers: %v", headers)
|
||||
}
|
||||
|
||||
return &upgradedStreamClient{conn: conn, reader: reader}
|
||||
}
|
||||
|
||||
func (c *upgradedStreamClient) echo(payload string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
deadline := time.Now().Add(1 * time.Second)
|
||||
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.WriteString(c.conn, payload); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(c.reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if string(buf) != payload {
|
||||
return fmt.Errorf("unexpected echoed payload: got %q want %q", string(buf), payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *upgradedStreamClient) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
type upgradeEchoBackend struct {
|
||||
addr string
|
||||
ln net.Listener
|
||||
mu sync.Mutex
|
||||
conns map[net.Conn]struct{}
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func newUpgradeEchoBackend(t *testing.T) *upgradeEchoBackend {
|
||||
t.Helper()
|
||||
|
||||
backend := &upgradeEchoBackend{conns: make(map[net.Conn]struct{})}
|
||||
backend.server = &http.Server{
|
||||
Handler: http.HandlerFunc(backend.serveHTTP),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listening for backend: %v", err)
|
||||
}
|
||||
backend.ln = ln
|
||||
backend.addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
_ = backend.server.Serve(ln)
|
||||
}()
|
||||
|
||||
return backend
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) serveHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.EqualFold(r.Header.Get("Connection"), "Upgrade") || !strings.EqualFold(r.Header.Get("Upgrade"), "stress-stream") {
|
||||
http.Error(w, "upgrade required", http.StatusUpgradeRequired)
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
conn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.trackConn(conn)
|
||||
_, _ = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: stress-stream\r\n\r\n")
|
||||
_ = rw.Flush()
|
||||
|
||||
go func() {
|
||||
defer b.untrackConn(conn)
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) trackConn(conn net.Conn) {
|
||||
b.mu.Lock()
|
||||
b.conns[conn] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) untrackConn(conn net.Conn) {
|
||||
b.mu.Lock()
|
||||
delete(b.conns, conn)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *upgradeEchoBackend) Close() {
|
||||
_ = b.server.Close()
|
||||
_ = b.ln.Close()
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
for conn := range b.conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
clear(b.conns)
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddycmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitModule(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedModule string
|
||||
expectedVersion string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "simple module without version",
|
||||
input: "github.com/caddyserver/caddy",
|
||||
expectedModule: "github.com/caddyserver/caddy",
|
||||
expectedVersion: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with version",
|
||||
input: "github.com/caddyserver/caddy@v2.0.0",
|
||||
expectedModule: "github.com/caddyserver/caddy",
|
||||
expectedVersion: "v2.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with semantic version",
|
||||
input: "github.com/user/module@v1.2.3",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "v1.2.3",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with prerelease version",
|
||||
input: "github.com/user/module@v1.0.0-beta.1",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "v1.0.0-beta.1",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with commit hash",
|
||||
input: "github.com/user/module@abc123def",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "abc123def",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with @ in path and version",
|
||||
input: "github.com/@user/module@v1.0.0",
|
||||
expectedModule: "github.com/@user/module",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with multiple @ in path",
|
||||
input: "github.com/@org/@user/module@v2.3.4",
|
||||
expectedModule: "github.com/@org/@user/module",
|
||||
expectedVersion: "v2.3.4",
|
||||
expectError: false,
|
||||
},
|
||||
// TODO: decide on the behavior for this case; it fails currently
|
||||
// {
|
||||
// name: "module with @ in path but no version",
|
||||
// input: "github.com/@user/module",
|
||||
// expectedModule: "github.com/@user/module",
|
||||
// expectedVersion: "",
|
||||
// expectError: false,
|
||||
// },
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectedModule: "",
|
||||
expectedVersion: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "only @ symbol",
|
||||
input: "@",
|
||||
expectedModule: "",
|
||||
expectedVersion: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "@ at start",
|
||||
input: "@v1.0.0",
|
||||
expectedModule: "",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "@ at end",
|
||||
input: "github.com/user/module@",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple consecutive @",
|
||||
input: "github.com/user/module@@v1.0.0",
|
||||
expectedModule: "github.com/user/module@",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "version with latest tag",
|
||||
input: "github.com/user/module@latest",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "latest",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "long module path",
|
||||
input: "github.com/organization/team/project/subproject/module@v3.14.159",
|
||||
expectedModule: "github.com/organization/team/project/subproject/module",
|
||||
expectedVersion: "v3.14.159",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with dots in name",
|
||||
input: "github.com/user/my.module.name@v1.0",
|
||||
expectedModule: "github.com/user/my.module.name",
|
||||
expectedVersion: "v1.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with hyphens",
|
||||
input: "github.com/user/my-module-name@v1.0.0",
|
||||
expectedModule: "github.com/user/my-module-name",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "gitlab module",
|
||||
input: "gitlab.com/user/module@v2.0.0",
|
||||
expectedModule: "gitlab.com/user/module",
|
||||
expectedVersion: "v2.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "bitbucket module",
|
||||
input: "bitbucket.org/user/module@v1.5.0",
|
||||
expectedModule: "bitbucket.org/user/module",
|
||||
expectedVersion: "v1.5.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "custom domain",
|
||||
input: "example.com/custom/module@v1.0.0",
|
||||
expectedModule: "example.com/custom/module",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
module, version, err := splitModule(tt.input)
|
||||
|
||||
// Check error expectation
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check module
|
||||
if module != tt.expectedModule {
|
||||
t.Errorf("module: got %q, want %q", module, tt.expectedModule)
|
||||
}
|
||||
|
||||
// Check version
|
||||
if version != tt.expectedVersion {
|
||||
t.Errorf("version: got %q, want %q", version, tt.expectedVersion)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitModule_ErrorCases(t *testing.T) {
|
||||
errorCases := []string{
|
||||
"",
|
||||
"@",
|
||||
"@version",
|
||||
"@v1.0.0",
|
||||
}
|
||||
|
||||
for _, tc := range errorCases {
|
||||
t.Run("error_"+tc, func(t *testing.T) {
|
||||
_, _, err := splitModule(tc)
|
||||
if err == nil {
|
||||
t.Errorf("splitModule(%q) should return error", tc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSplitModule benchmarks the splitModule function
|
||||
func BenchmarkSplitModule(b *testing.B) {
|
||||
testCases := []string{
|
||||
"github.com/user/module",
|
||||
"github.com/user/module@v1.0.0",
|
||||
"github.com/@org/@user/module@v2.3.4",
|
||||
"github.com/organization/team/project/subproject/module@v3.14.159",
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
splitModule(tc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+720
@@ -0,0 +1,720 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConfig_Start_Stop_Basic(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true}, // Disable admin to avoid port conflicts
|
||||
}
|
||||
|
||||
ctx, err := run(cfg, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run config: %v", err)
|
||||
}
|
||||
|
||||
// Verify context is valid
|
||||
if ctx.cfg == nil {
|
||||
t.Error("Expected non-nil config in context")
|
||||
}
|
||||
|
||||
// Stop the config
|
||||
unsyncedStop(ctx)
|
||||
|
||||
// Verify cleanup was called
|
||||
if ctx.cfg.cancelFunc == nil {
|
||||
t.Error("Expected cancel function to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_InvalidConfig(t *testing.T) {
|
||||
// Create a config with an invalid app module
|
||||
cfg := &Config{
|
||||
AppsRaw: ModuleMap{
|
||||
"non-existent-app": json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
err := Validate(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected validation error for invalid app module")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_ValidConfig(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true},
|
||||
}
|
||||
|
||||
err := Validate(cfg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected validation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfig_ConcurrentAccess(t *testing.T) {
|
||||
// Save original config state
|
||||
originalRawCfg := rawCfg[rawConfigKey]
|
||||
originalRawCfgJSON := rawCfgJSON
|
||||
defer func() {
|
||||
rawCfg[rawConfigKey] = originalRawCfg
|
||||
rawCfgJSON = originalRawCfgJSON
|
||||
}()
|
||||
|
||||
// Initialize with a basic config
|
||||
initialCfg := map[string]any{
|
||||
"test": "value",
|
||||
}
|
||||
rawCfg[rawConfigKey] = initialCfg
|
||||
|
||||
const numGoroutines = 10 // Reduced for more controlled testing
|
||||
var wg sync.WaitGroup
|
||||
errors := make([]error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Only test read operations to avoid complex state changes
|
||||
// that could cause nil pointer issues in concurrent scenarios
|
||||
var buf bytes.Buffer
|
||||
errors[index] = readConfig("/"+rawConfigKey+"/test", &buf)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check that read operations succeeded
|
||||
for i, err := range errors {
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: Unexpected read error: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfig_MethodValidation(t *testing.T) {
|
||||
// Save original config state
|
||||
originalRawCfg := rawCfg[rawConfigKey]
|
||||
defer func() {
|
||||
rawCfg[rawConfigKey] = originalRawCfg
|
||||
}()
|
||||
|
||||
// Set up a simple valid config for testing
|
||||
rawCfg[rawConfigKey] = map[string]any{}
|
||||
|
||||
tests := []struct {
|
||||
method string
|
||||
expectErr bool
|
||||
}{
|
||||
{http.MethodPost, false},
|
||||
{http.MethodPut, true}, // because key 'admin' already exists
|
||||
{http.MethodPatch, false},
|
||||
{http.MethodDelete, false},
|
||||
{http.MethodGet, true},
|
||||
{http.MethodHead, true},
|
||||
{http.MethodOptions, true},
|
||||
{http.MethodConnect, true},
|
||||
{http.MethodTrace, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.method, func(t *testing.T) {
|
||||
// Use a simple admin config path that won't cause complex validation
|
||||
err := changeConfig(test.method, "/"+rawConfigKey+"/admin", []byte(`{"disabled": true}`), "", false)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error for invalid method")
|
||||
}
|
||||
if !test.expectErr && err != nil && (err != errSameConfig) {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfig_IfMatchHeader_Validation(t *testing.T) {
|
||||
// Set up initial config
|
||||
initialCfg := map[string]any{"test": "value"}
|
||||
rawCfg[rawConfigKey] = initialCfg
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ifMatch string
|
||||
expectErr bool
|
||||
expectStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "malformed - no quotes",
|
||||
ifMatch: "path hash",
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "malformed - single quote",
|
||||
ifMatch: `"path hash`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "malformed - wrong number of parts",
|
||||
ifMatch: `"path"`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "malformed - too many parts",
|
||||
ifMatch: `"path hash extra"`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "wrong hash",
|
||||
ifMatch: `"/config/test wronghash"`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusPreconditionFailed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := changeConfig(http.MethodPost, "/"+rawConfigKey+"/test", []byte(`"newvalue"`), test.ifMatch, false)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if test.expectErr && err != nil {
|
||||
if apiErr, ok := err.(APIError); ok {
|
||||
if apiErr.HTTPStatus != test.expectStatusCode {
|
||||
t.Errorf("Expected status %d, got %d", test.expectStatusCode, apiErr.HTTPStatus)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected APIError type")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexConfigObjects_Basic(t *testing.T) {
|
||||
config := map[string]any{
|
||||
"app1": map[string]any{
|
||||
"@id": "my-app",
|
||||
"config": "value",
|
||||
},
|
||||
"nested": map[string]any{
|
||||
"array": []any{
|
||||
map[string]any{
|
||||
"@id": "nested-item",
|
||||
"data": "test",
|
||||
},
|
||||
map[string]any{
|
||||
"@id": 123.0, // JSON numbers are float64
|
||||
"more": "data",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
index := make(map[string]string)
|
||||
err := indexConfigObjects(config, "/config", index)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := map[string]string{
|
||||
"my-app": "/config/app1",
|
||||
"nested-item": "/config/nested/array/0",
|
||||
"123": "/config/nested/array/1",
|
||||
}
|
||||
|
||||
if len(index) != len(expected) {
|
||||
t.Errorf("Expected %d indexed items, got %d", len(expected), len(index))
|
||||
}
|
||||
|
||||
for id, expectedPath := range expected {
|
||||
if actualPath, exists := index[id]; !exists || actualPath != expectedPath {
|
||||
t.Errorf("ID %s: expected path '%s', got '%s'", id, expectedPath, actualPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexConfigObjects_InvalidID(t *testing.T) {
|
||||
config := map[string]any{
|
||||
"app": map[string]any{
|
||||
"@id": map[string]any{"invalid": "id"}, // Invalid ID type
|
||||
},
|
||||
}
|
||||
|
||||
index := make(map[string]string)
|
||||
err := indexConfigObjects(config, "/config", index)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid ID type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AppStartFailure(t *testing.T) {
|
||||
// Register a mock app that fails to start
|
||||
RegisterModule(&failingApp{})
|
||||
defer func() {
|
||||
// Clean up module registry
|
||||
delete(modules, "failing-app")
|
||||
}()
|
||||
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true},
|
||||
AppsRaw: ModuleMap{
|
||||
"failing-app": json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := run(cfg, true)
|
||||
if err == nil {
|
||||
t.Error("Expected error when app fails to start")
|
||||
}
|
||||
|
||||
// Should contain the app name in the error
|
||||
if err.Error() == "" {
|
||||
t.Error("Expected descriptive error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AppStopFailure_During_Cleanup(t *testing.T) {
|
||||
// Register apps where one fails to start and another fails to stop
|
||||
RegisterModule(&workingApp{})
|
||||
RegisterModule(&failingStopApp{})
|
||||
defer func() {
|
||||
delete(modules, "working-app")
|
||||
delete(modules, "failing-stop-app")
|
||||
}()
|
||||
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true},
|
||||
AppsRaw: ModuleMap{
|
||||
"working-app": json.RawMessage(`{}`),
|
||||
"failing-stop-app": json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
// Start both apps
|
||||
ctx, err := run(cfg, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error starting apps: %v", err)
|
||||
}
|
||||
|
||||
// Stop context - this should handle stop failures gracefully
|
||||
unsyncedStop(ctx)
|
||||
|
||||
// Test passed if we reach here without panic
|
||||
}
|
||||
|
||||
func TestProvisionContext_NilConfig(t *testing.T) {
|
||||
ctx, err := provisionContext(nil, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if ctx.cfg == nil {
|
||||
t.Error("Expected non-nil config even when input is nil")
|
||||
}
|
||||
|
||||
// Clean up
|
||||
// TODO: Investigate
|
||||
ctx.cfg.cancelFunc(nil)
|
||||
}
|
||||
|
||||
func TestDuration_UnmarshalJSON_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "integer nanoseconds",
|
||||
input: "1000000000",
|
||||
expected: time.Second,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "string duration",
|
||||
input: `"5m30s"`,
|
||||
expected: 5*time.Minute + 30*time.Second,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "days conversion",
|
||||
input: `"2d"`,
|
||||
expected: 48 * time.Hour,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed days and hours",
|
||||
input: `"1d12h"`,
|
||||
expected: 36 * time.Hour,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid duration",
|
||||
input: `"invalid"`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var d Duration
|
||||
err := d.UnmarshalJSON([]byte(test.input))
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !test.expectErr && time.Duration(d) != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, time.Duration(d))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_LongInput(t *testing.T) {
|
||||
// Test input length limit
|
||||
longInput := string(make([]byte, 1025)) // Exceeds 1024 limit
|
||||
for i := range longInput {
|
||||
longInput = longInput[:i] + "1"
|
||||
}
|
||||
longInput += "d"
|
||||
|
||||
_, err := ParseDuration(longInput)
|
||||
if err == nil {
|
||||
t.Error("Expected error for input longer than 1024 characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_Deterministic(t *testing.T) {
|
||||
// Test that Version() returns consistent results
|
||||
simple1, full1 := Version()
|
||||
simple2, full2 := Version()
|
||||
|
||||
if simple1 != simple2 {
|
||||
t.Errorf("Version() simple form not deterministic: '%s' != '%s'", simple1, simple2)
|
||||
}
|
||||
if full1 != full2 {
|
||||
t.Errorf("Version() full form not deterministic: '%s' != '%s'", full1, full2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstanceID_Consistency(t *testing.T) {
|
||||
// Test that InstanceID returns the same ID on subsequent calls
|
||||
id1, err := InstanceID()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get instance ID: %v", err)
|
||||
}
|
||||
|
||||
id2, err := InstanceID()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get instance ID on second call: %v", err)
|
||||
}
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("InstanceID not consistent: %v != %v", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveMetaFields_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no meta fields",
|
||||
input: `{"normal": "field"}`,
|
||||
expected: `{"normal": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "single @id field",
|
||||
input: `{"@id": "test", "other": "field"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "@id at beginning",
|
||||
input: `{"@id": "test", "other": "field"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "@id at end",
|
||||
input: `{"other": "field", "@id": "test"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "@id in middle",
|
||||
input: `{"first": "value", "@id": "test", "last": "value"}`,
|
||||
expected: `{"first": "value", "last": "value"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple @id fields",
|
||||
input: `{"@id": "test1", "other": "field", "@id": "test2"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "numeric @id",
|
||||
input: `{"@id": 123, "other": "field"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "nested objects with @id",
|
||||
input: `{"outer": {"@id": "nested", "data": "value"}}`,
|
||||
expected: `{"outer": {"data": "value"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := RemoveMetaFields([]byte(test.input))
|
||||
// resultStr := string(result)
|
||||
|
||||
// Parse both to ensure valid JSON and compare structures
|
||||
var expectedObj, resultObj any
|
||||
if err := json.Unmarshal([]byte(test.expected), &expectedObj); err != nil {
|
||||
t.Fatalf("Expected result is not valid JSON: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(result, &resultObj); err != nil {
|
||||
t.Fatalf("Result is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Note: We can't do exact string comparison due to potential field ordering
|
||||
// Instead, verify the structure matches
|
||||
expectedJSON, _ := json.Marshal(expectedObj)
|
||||
resultJSON, _ := json.Marshal(resultObj)
|
||||
|
||||
if string(expectedJSON) != string(resultJSON) {
|
||||
t.Errorf("Expected %s, got %s", string(expectedJSON), string(resultJSON))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsyncedConfigAccess_ArrayOperations_EdgeCases(t *testing.T) {
|
||||
// Test array boundary conditions and edge cases
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState map[string]any
|
||||
method string
|
||||
path string
|
||||
payload string
|
||||
expectErr bool
|
||||
expectState map[string]any
|
||||
}{
|
||||
{
|
||||
name: "delete from empty array",
|
||||
initialState: map[string]any{"arr": []any{}},
|
||||
method: http.MethodDelete,
|
||||
path: "/config/arr/0",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "access negative index",
|
||||
initialState: map[string]any{"arr": []any{"a", "b"}},
|
||||
method: http.MethodGet,
|
||||
path: "/config/arr/-1",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "put at index beyond end",
|
||||
initialState: map[string]any{"arr": []any{"a"}},
|
||||
method: http.MethodPut,
|
||||
path: "/config/arr/5",
|
||||
payload: `"new"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "patch non-existent index",
|
||||
initialState: map[string]any{"arr": []any{"a"}},
|
||||
method: http.MethodPatch,
|
||||
path: "/config/arr/5",
|
||||
payload: `"new"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "put at exact end of array",
|
||||
initialState: map[string]any{"arr": []any{"a", "b"}},
|
||||
method: http.MethodPut,
|
||||
path: "/config/arr/2",
|
||||
payload: `"c"`,
|
||||
expectState: map[string]any{"arr": []any{"a", "b", "c"}},
|
||||
},
|
||||
{
|
||||
name: "ellipses with non-array payload",
|
||||
initialState: map[string]any{"arr": []any{"a"}},
|
||||
method: http.MethodPost,
|
||||
path: "/config/arr/...",
|
||||
payload: `"not-array"`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Set up initial state
|
||||
rawCfg[rawConfigKey] = test.initialState
|
||||
|
||||
err := unsyncedConfigAccess(test.method, test.path, []byte(test.payload), nil)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if test.expectState != nil {
|
||||
// Compare resulting state
|
||||
expectedJSON, _ := json.Marshal(test.expectState)
|
||||
actualJSON, _ := json.Marshal(rawCfg[rawConfigKey])
|
||||
|
||||
if string(expectedJSON) != string(actualJSON) {
|
||||
t.Errorf("Expected state %s, got %s", string(expectedJSON), string(actualJSON))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitProcess_ConcurrentCalls(t *testing.T) {
|
||||
// Test that multiple concurrent calls to exitProcess are safe
|
||||
// We can't test the actual exit, but we can test the atomic flag
|
||||
|
||||
// Reset the exiting flag
|
||||
oldExiting := exiting
|
||||
exiting = new(int32)
|
||||
defer func() { exiting = oldExiting }()
|
||||
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([]bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
// Check the Exiting() function which reads the atomic flag
|
||||
wasExitingBefore := Exiting()
|
||||
|
||||
// This would call exitProcess, but we don't want to actually exit
|
||||
// So we just test the atomic operation directly
|
||||
results[index] = atomic.CompareAndSwapInt32(exiting, 0, 1)
|
||||
|
||||
wasExitingAfter := Exiting()
|
||||
|
||||
// At least one should succeed in setting the flag
|
||||
if !wasExitingBefore && wasExitingAfter && !results[index] {
|
||||
t.Errorf("Goroutine %d: Flag was set but CAS failed", index)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Exactly one goroutine should have successfully set the flag
|
||||
successCount := 0
|
||||
for _, success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
if successCount != 1 {
|
||||
t.Errorf("Expected exactly 1 successful flag set, got %d", successCount)
|
||||
}
|
||||
|
||||
// Flag should be set
|
||||
if !Exiting() {
|
||||
t.Error("Exiting flag should be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock apps for testing
|
||||
type failingApp struct{}
|
||||
|
||||
func (fa *failingApp) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "failing-app",
|
||||
New: func() Module { return new(failingApp) },
|
||||
}
|
||||
}
|
||||
|
||||
func (fa *failingApp) Start() error {
|
||||
return fmt.Errorf("simulated start failure")
|
||||
}
|
||||
|
||||
func (fa *failingApp) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type workingApp struct{}
|
||||
|
||||
func (wa *workingApp) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "working-app",
|
||||
New: func() Module { return new(workingApp) },
|
||||
}
|
||||
}
|
||||
|
||||
func (wa *workingApp) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wa *workingApp) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type failingStopApp struct{}
|
||||
|
||||
func (fsa *failingStopApp) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "failing-stop-app",
|
||||
New: func() Module { return new(failingStopApp) },
|
||||
}
|
||||
}
|
||||
|
||||
func (fsa *failingStopApp) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fsa *failingStopApp) Stop() error {
|
||||
return fmt.Errorf("simulated stop failure")
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseDuration_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "zero duration",
|
||||
input: "0",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid format",
|
||||
input: "abc",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative days",
|
||||
input: "-2d",
|
||||
expected: -48 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "decimal days",
|
||||
input: "0.5d",
|
||||
expected: 12 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "large decimal days",
|
||||
input: "365.25d",
|
||||
expected: time.Duration(365.25*24) * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "multiple days in same string",
|
||||
input: "1d2d3d",
|
||||
expected: (24 * 6) * time.Hour, // 6 days total
|
||||
},
|
||||
{
|
||||
name: "days with other units",
|
||||
input: "1d30m15s",
|
||||
expected: 24*time.Hour + 30*time.Minute + 15*time.Second,
|
||||
},
|
||||
{
|
||||
name: "malformed days",
|
||||
input: "d",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid day value",
|
||||
input: "abcd",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "overflow protection",
|
||||
input: "9999999999999999999999999d",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero days",
|
||||
input: "0d",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "input at limit",
|
||||
input: strings.Repeat("1", 1024) + "ns",
|
||||
expectErr: true, // Likely to cause parsing error due to size
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result, err := ParseDuration(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !test.expectErr && result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_InputLengthLimit(t *testing.T) {
|
||||
// Test the 1024 character limit
|
||||
longInput := strings.Repeat("1", 1025) + "s"
|
||||
|
||||
_, err := ParseDuration(longInput)
|
||||
if err == nil {
|
||||
t.Error("Expected error for input longer than 1024 characters")
|
||||
}
|
||||
|
||||
expectedErrMsg := "parsing duration: input string too long"
|
||||
if err.Error() != expectedErrMsg {
|
||||
t.Errorf("Expected error message '%s', got '%s'", expectedErrMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_ComplexNumberFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
input: "+1d",
|
||||
expected: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "-1.5d",
|
||||
expected: -36 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "1.0d",
|
||||
expected: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "0.25d",
|
||||
expected: 6 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "1.5d30m",
|
||||
expected: 36*time.Hour + 30*time.Minute,
|
||||
},
|
||||
{
|
||||
input: "2.5d1h30m45s",
|
||||
expected: 60*time.Hour + time.Hour + 30*time.Minute + 45*time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
result, err := ParseDuration(test.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_UnmarshalJSON_TypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "null value",
|
||||
input: "null",
|
||||
expectErr: false,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "boolean value",
|
||||
input: "true",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "array value",
|
||||
input: `[1,2,3]`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "object value",
|
||||
input: `{"duration": "5m"}`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative integer",
|
||||
input: "-1000000000",
|
||||
expected: -time.Second,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero integer",
|
||||
input: "0",
|
||||
expected: 0,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "large integer",
|
||||
input: "9223372036854775807", // Max int64
|
||||
expected: time.Duration(math.MaxInt64),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "float as integer (invalid JSON for int)",
|
||||
input: "1.5",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "string with special characters",
|
||||
input: `"5m\"30s"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "string with unicode",
|
||||
input: `"5m🚀"`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var d Duration
|
||||
err := d.UnmarshalJSON([]byte(test.input))
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !test.expectErr && time.Duration(d) != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, time.Duration(d))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_JSON_RoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
duration time.Duration
|
||||
asString bool
|
||||
}{
|
||||
{duration: 5 * time.Minute, asString: true},
|
||||
{duration: 24 * time.Hour, asString: false}, // Will be stored as nanoseconds
|
||||
{duration: 0, asString: false},
|
||||
{duration: -time.Hour, asString: true},
|
||||
{duration: time.Nanosecond, asString: false},
|
||||
{duration: time.Second, asString: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.duration.String(), func(t *testing.T) {
|
||||
d := Duration(test.duration)
|
||||
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var unmarshaled Duration
|
||||
err = unmarshaled.UnmarshalJSON(jsonData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Should be equal
|
||||
if time.Duration(unmarshaled) != test.duration {
|
||||
t.Errorf("Round trip failed: expected %v, got %v", test.duration, time.Duration(unmarshaled))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_Precision(t *testing.T) {
|
||||
// Test floating point precision with days
|
||||
tests := []struct {
|
||||
input string
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
input: "0.1d",
|
||||
expected: time.Duration(0.1 * 24 * float64(time.Hour)),
|
||||
},
|
||||
{
|
||||
input: "0.01d",
|
||||
expected: time.Duration(0.01 * 24 * float64(time.Hour)),
|
||||
},
|
||||
{
|
||||
input: "0.001d",
|
||||
expected: time.Duration(0.001 * 24 * float64(time.Hour)),
|
||||
},
|
||||
{
|
||||
input: "1.23456789d",
|
||||
expected: time.Duration(1.23456789 * 24 * float64(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
result, err := ParseDuration(test.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Allow for small floating point differences
|
||||
diff := result - test.expected
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff > time.Nanosecond {
|
||||
t.Errorf("Expected %v, got %v (diff: %v)", test.expected, result, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_Boundary_Values(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "minimum day value",
|
||||
input: "0.000000001d", // Very small but valid
|
||||
},
|
||||
{
|
||||
name: "very large day value",
|
||||
input: "999999999999999999999d",
|
||||
expectErr: true, // Should overflow
|
||||
},
|
||||
{
|
||||
name: "negative zero",
|
||||
input: "-0d",
|
||||
},
|
||||
{
|
||||
name: "positive zero",
|
||||
input: "+0d",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
_, err := ParseDuration(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDuration_SimpleDay(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ParseDuration("1d")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDuration_ComplexDay(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ParseDuration("1.5d30m15.5s")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDuration_MultipleDays(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ParseDuration("1d2d3d4d5d")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDuration_UnmarshalJSON_String(b *testing.B) {
|
||||
input := []byte(`"5m30s"`)
|
||||
var d Duration
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.UnmarshalJSON(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDuration_UnmarshalJSON_Integer(b *testing.B) {
|
||||
input := []byte("300000000000") // 5 minutes in nanoseconds
|
||||
var d Duration
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.UnmarshalJSON(input)
|
||||
}
|
||||
}
|
||||
+642
@@ -0,0 +1,642 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewEvent_Basic(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventName := "test.event"
|
||||
eventData := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, eventName, eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
// Verify event properties
|
||||
if event.Name() != eventName {
|
||||
t.Errorf("Expected name '%s', got '%s'", eventName, event.Name())
|
||||
}
|
||||
|
||||
if event.Data == nil {
|
||||
t.Error("Expected non-nil data")
|
||||
}
|
||||
|
||||
if len(event.Data) != len(eventData) {
|
||||
t.Errorf("Expected %d data items, got %d", len(eventData), len(event.Data))
|
||||
}
|
||||
|
||||
for key, expectedValue := range eventData {
|
||||
if actualValue, exists := event.Data[key]; !exists || actualValue != expectedValue {
|
||||
t.Errorf("Data key '%s': expected %v, got %v", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify ID is generated
|
||||
if event.ID().String() == "" {
|
||||
t.Error("Event ID should not be empty")
|
||||
}
|
||||
|
||||
// Verify timestamp is recent
|
||||
if time.Since(event.Timestamp()) > time.Second {
|
||||
t.Error("Event timestamp should be recent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEvent_NameNormalization(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"UPPERCASE", "uppercase"},
|
||||
{"MixedCase", "mixedcase"},
|
||||
{"already.lower", "already.lower"},
|
||||
{"With-Dashes", "with-dashes"},
|
||||
{"With_Underscores", "with_underscores"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
event, err := NewEvent(ctx, test.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
if event.Name() != test.expected {
|
||||
t.Errorf("Expected normalized name '%s', got '%s'", test.expected, event.Name())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_CloudEvent_NilData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Should not panic with nil data
|
||||
if cloudEvent.Data == nil {
|
||||
t.Error("CloudEvent data should not be nil even with nil input")
|
||||
}
|
||||
|
||||
// Should be valid JSON
|
||||
var parsed any
|
||||
if err := json.Unmarshal(cloudEvent.Data, &parsed); err != nil {
|
||||
t.Errorf("CloudEvent data should be valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_CloudEvent_WithModule(t *testing.T) {
|
||||
// Create a context with a mock module
|
||||
mockMod := &mockModule{}
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Simulate module ancestry
|
||||
ctx.ancestry = []Module{mockMod}
|
||||
|
||||
event, err := NewEvent(ctx, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Source should be the module ID
|
||||
expectedSource := string(mockMod.CaddyModule().ID)
|
||||
if cloudEvent.Source != expectedSource {
|
||||
t.Errorf("Expected source '%s', got '%s'", expectedSource, cloudEvent.Source)
|
||||
}
|
||||
|
||||
// Origin should be the module
|
||||
if event.Origin() != mockMod {
|
||||
t.Error("Expected event origin to be the mock module")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_CloudEvent_Fields(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventName := "test.event"
|
||||
eventData := map[string]any{"test": "data"}
|
||||
|
||||
event, err := NewEvent(ctx, eventName, eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Verify CloudEvent fields
|
||||
if cloudEvent.ID == "" {
|
||||
t.Error("CloudEvent ID should not be empty")
|
||||
}
|
||||
|
||||
if cloudEvent.Source != "caddy" {
|
||||
t.Errorf("Expected source 'caddy' for nil module, got '%s'", cloudEvent.Source)
|
||||
}
|
||||
|
||||
if cloudEvent.SpecVersion != "1.0" {
|
||||
t.Errorf("Expected spec version '1.0', got '%s'", cloudEvent.SpecVersion)
|
||||
}
|
||||
|
||||
if cloudEvent.Type != eventName {
|
||||
t.Errorf("Expected type '%s', got '%s'", eventName, cloudEvent.Type)
|
||||
}
|
||||
|
||||
if cloudEvent.DataContentType != "application/json" {
|
||||
t.Errorf("Expected content type 'application/json', got '%s'", cloudEvent.DataContentType)
|
||||
}
|
||||
|
||||
// Verify data is valid JSON
|
||||
var parsedData map[string]any
|
||||
if err := json.Unmarshal(cloudEvent.Data, &parsedData); err != nil {
|
||||
t.Errorf("CloudEvent data is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsedData["test"] != "data" {
|
||||
t.Errorf("Expected data to contain test='data', got %v", parsedData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_ConcurrentAccess(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "concurrent.test", map[string]any{
|
||||
"counter": 0,
|
||||
"data": "shared",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
const numGoroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Test concurrent read access to event properties
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// These should be safe for concurrent access
|
||||
_ = event.ID()
|
||||
_ = event.Name()
|
||||
_ = event.Timestamp()
|
||||
_ = event.Origin()
|
||||
_ = event.CloudEvent()
|
||||
|
||||
// Data map is not synchronized, so read-only access should be safe
|
||||
if data, exists := event.Data["data"]; !exists || data != "shared" {
|
||||
t.Errorf("Goroutine %d: Expected shared data", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestEvent_DataModification_Warning(t *testing.T) {
|
||||
// This test documents the non-thread-safe nature of event data
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "data.test", map[string]any{
|
||||
"mutable": "original",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
// Modifying data after creation (this is allowed but not thread-safe)
|
||||
event.Data["mutable"] = "modified"
|
||||
event.Data["new_key"] = "new_value"
|
||||
|
||||
// Verify modifications are visible
|
||||
if event.Data["mutable"] != "modified" {
|
||||
t.Error("Data modification should be visible")
|
||||
}
|
||||
if event.Data["new_key"] != "new_value" {
|
||||
t.Error("New data should be visible")
|
||||
}
|
||||
|
||||
// CloudEvent should reflect the current state
|
||||
cloudEvent := event.CloudEvent()
|
||||
var parsedData map[string]any
|
||||
json.Unmarshal(cloudEvent.Data, &parsedData)
|
||||
|
||||
if parsedData["mutable"] != "modified" {
|
||||
t.Error("CloudEvent should reflect modified data")
|
||||
}
|
||||
if parsedData["new_key"] != "new_value" {
|
||||
t.Error("CloudEvent should reflect new data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_Aborted_State(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "abort.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
// Initially not aborted
|
||||
if event.Aborted != nil {
|
||||
t.Error("Event should not be aborted initially")
|
||||
}
|
||||
|
||||
// Simulate aborting the event
|
||||
event.Aborted = ErrEventAborted
|
||||
|
||||
if event.Aborted != ErrEventAborted {
|
||||
t.Error("Event should be marked as aborted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrEventAborted_Value(t *testing.T) {
|
||||
if ErrEventAborted == nil {
|
||||
t.Error("ErrEventAborted should not be nil")
|
||||
}
|
||||
|
||||
if ErrEventAborted.Error() != "event aborted" {
|
||||
t.Errorf("Expected 'event aborted', got '%s'", ErrEventAborted.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_UniqueIDs(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
const numEvents = 1000
|
||||
ids := make(map[string]bool)
|
||||
|
||||
for i := 0; i < numEvents; i++ {
|
||||
event, err := NewEvent(ctx, "unique.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event %d: %v", i, err)
|
||||
}
|
||||
|
||||
idStr := event.ID().String()
|
||||
if ids[idStr] {
|
||||
t.Errorf("Duplicate event ID: %s", idStr)
|
||||
}
|
||||
ids[idStr] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_TimestampProgression(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Create events with small delays
|
||||
events := make([]Event, 5)
|
||||
for i := range events {
|
||||
var err error
|
||||
events[i], err = NewEvent(ctx, "time.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event %d: %v", i, err)
|
||||
}
|
||||
|
||||
if i < len(events)-1 {
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify timestamps are in ascending order
|
||||
for i := 1; i < len(events); i++ {
|
||||
if !events[i].Timestamp().After(events[i-1].Timestamp()) {
|
||||
t.Errorf("Event %d timestamp (%v) should be after event %d timestamp (%v)",
|
||||
i, events[i].Timestamp(), i-1, events[i-1].Timestamp())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_JSON_Serialization(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventData := map[string]any{
|
||||
"string": "value",
|
||||
"number": 42,
|
||||
"boolean": true,
|
||||
"array": []any{1, 2, 3},
|
||||
"object": map[string]any{"nested": "value"},
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, "json.test", eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// CloudEvent should be JSON serializable
|
||||
cloudEventJSON, err := json.Marshal(cloudEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CloudEvent: %v", err)
|
||||
}
|
||||
|
||||
// Should be able to unmarshal back
|
||||
var parsed CloudEvent
|
||||
err = json.Unmarshal(cloudEventJSON, &parsed)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal CloudEvent: %v", err)
|
||||
}
|
||||
|
||||
// Verify key fields survived round-trip
|
||||
if parsed.ID != cloudEvent.ID {
|
||||
t.Errorf("ID mismatch after round-trip")
|
||||
}
|
||||
if parsed.Source != cloudEvent.Source {
|
||||
t.Errorf("Source mismatch after round-trip")
|
||||
}
|
||||
if parsed.Type != cloudEvent.Type {
|
||||
t.Errorf("Type mismatch after round-trip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_EmptyData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Test with empty map
|
||||
event1, err := NewEvent(ctx, "empty.map", map[string]any{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event with empty map: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent1 := event1.CloudEvent()
|
||||
var parsed1 map[string]any
|
||||
json.Unmarshal(cloudEvent1.Data, &parsed1)
|
||||
if len(parsed1) != 0 {
|
||||
t.Error("Expected empty data map")
|
||||
}
|
||||
|
||||
// Test with nil data
|
||||
event2, err := NewEvent(ctx, "nil.data", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event with nil data: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent2 := event2.CloudEvent()
|
||||
if cloudEvent2.Data == nil {
|
||||
t.Error("CloudEvent data should not be nil even with nil input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_Origin_WithModule(t *testing.T) {
|
||||
mockMod := &mockEventModule{}
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Set module in ancestry
|
||||
ctx.ancestry = []Module{mockMod}
|
||||
|
||||
event, err := NewEvent(ctx, "module.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
if event.Origin() != mockMod {
|
||||
t.Error("Expected event origin to be the mock module")
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
expectedSource := string(mockMod.CaddyModule().ID)
|
||||
if cloudEvent.Source != expectedSource {
|
||||
t.Errorf("Expected source '%s', got '%s'", expectedSource, cloudEvent.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_LargeData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Create event with large data
|
||||
largeData := make(map[string]any)
|
||||
for i := 0; i < 1000; i++ {
|
||||
largeData[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, "large.data", largeData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event with large data: %v", err)
|
||||
}
|
||||
|
||||
// CloudEvent should handle large data
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
var parsedData map[string]any
|
||||
err = json.Unmarshal(cloudEvent.Data, &parsedData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse large data in CloudEvent: %v", err)
|
||||
}
|
||||
|
||||
if len(parsedData) != len(largeData) {
|
||||
t.Errorf("Expected %d data items, got %d", len(largeData), len(parsedData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_SpecialCharacters_InData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
specialData := map[string]any{
|
||||
"unicode": "🚀✨",
|
||||
"newlines": "line1\nline2\r\nline3",
|
||||
"quotes": `"double" and 'single' quotes`,
|
||||
"backslashes": "\\path\\to\\file",
|
||||
"json_chars": `{"key": "value"}`,
|
||||
"empty": "",
|
||||
"null_value": nil,
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, "special.chars", specialData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Should produce valid JSON
|
||||
var parsedData map[string]any
|
||||
err = json.Unmarshal(cloudEvent.Data, &parsedData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse data with special characters: %v", err)
|
||||
}
|
||||
|
||||
// Verify some special cases survived JSON round-trip
|
||||
if parsedData["unicode"] != "🚀✨" {
|
||||
t.Error("Unicode characters should survive JSON encoding")
|
||||
}
|
||||
|
||||
if parsedData["quotes"] != `"double" and 'single' quotes` {
|
||||
t.Error("Quotes should be properly escaped in JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_ConcurrentCreation(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
const numGoroutines = 100
|
||||
var wg sync.WaitGroup
|
||||
events := make([]Event, numGoroutines)
|
||||
errors := make([]error, numGoroutines)
|
||||
|
||||
// Create events concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
|
||||
eventData := map[string]any{
|
||||
"goroutine": index,
|
||||
"timestamp": time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
events[index], errors[index] = NewEvent(ctx, "concurrent.test", eventData)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all events were created successfully
|
||||
ids := make(map[string]bool)
|
||||
for i, event := range events {
|
||||
if errors[i] != nil {
|
||||
t.Errorf("Goroutine %d: Failed to create event: %v", i, errors[i])
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify unique IDs
|
||||
idStr := event.ID().String()
|
||||
if ids[idStr] {
|
||||
t.Errorf("Duplicate event ID: %s", idStr)
|
||||
}
|
||||
ids[idStr] = true
|
||||
|
||||
// Verify data integrity
|
||||
if goroutineID, exists := event.Data["goroutine"]; !exists || goroutineID != i {
|
||||
t.Errorf("Event %d: Data corruption detected", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mock module for event testing
|
||||
type mockEventModule struct{}
|
||||
|
||||
func (m *mockEventModule) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "test.event.module",
|
||||
New: func() Module { return new(mockEventModule) },
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_TimeAccuracy(t *testing.T) {
|
||||
before := time.Now()
|
||||
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "time.accuracy", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
after := time.Now()
|
||||
eventTime := event.Timestamp()
|
||||
|
||||
// Event timestamp should be between before and after
|
||||
if eventTime.Before(before) || eventTime.After(after) {
|
||||
t.Errorf("Event timestamp %v should be between %v and %v", eventTime, before, after)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewEvent(b *testing.B) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventData := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
"key3": true,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewEvent(ctx, "benchmark.test", eventData)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEvent_CloudEvent(b *testing.B) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, _ := NewEvent(ctx, "benchmark.cloud", map[string]any{
|
||||
"data": "test",
|
||||
"num": 123,
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
event.CloudEvent()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEvent_CloudEvent_LargeData(b *testing.B) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Create event with substantial data
|
||||
largeData := make(map[string]any)
|
||||
for i := 0; i < 100; i++ {
|
||||
largeData[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
|
||||
}
|
||||
|
||||
event, _ := NewEvent(ctx, "benchmark.large", largeData)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
event.CloudEvent()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,221 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFastAbs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
checkFunc func(result string, err error) error
|
||||
}{
|
||||
{
|
||||
name: "absolute path",
|
||||
input: "/usr/local/bin",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if result != "/usr/local/bin" {
|
||||
t.Errorf("expected /usr/local/bin, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "absolute path with dots",
|
||||
input: "/usr/local/../bin",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if result != "/usr/bin" {
|
||||
t.Errorf("expected /usr/bin, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "relative path",
|
||||
input: "relative/path",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
if !strings.HasSuffix(result, "relative/path") {
|
||||
t.Errorf("expected path to end with 'relative/path', got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dot",
|
||||
input: ".",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dot dot",
|
||||
input: "..",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
// Empty string should resolve to current directory
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex relative path",
|
||||
input: "./foo/../bar/./baz",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
if !strings.HasSuffix(result, "bar/baz") {
|
||||
t.Errorf("expected path to end with 'bar/baz', got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FastAbs(tt.input)
|
||||
tt.checkFunc(result, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastAbsVsFilepathAbs compares FastAbs with filepath.Abs to ensure consistent behavior
|
||||
func TestFastAbsVsFilepathAbs(t *testing.T) {
|
||||
// Skip if working directory cannot be determined
|
||||
if wderr != nil {
|
||||
t.Skip("working directory error, skipping comparison test")
|
||||
}
|
||||
|
||||
testPaths := []string{
|
||||
".",
|
||||
"..",
|
||||
"foo",
|
||||
"foo/bar",
|
||||
"./foo",
|
||||
"../foo",
|
||||
"/absolute/path",
|
||||
"/usr/local/bin",
|
||||
}
|
||||
|
||||
for _, path := range testPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
fast, fastErr := FastAbs(path)
|
||||
std, stdErr := filepath.Abs(path)
|
||||
|
||||
// Both should succeed or fail together
|
||||
if (fastErr != nil) != (stdErr != nil) {
|
||||
t.Errorf("error mismatch: FastAbs=%v, filepath.Abs=%v", fastErr, stdErr)
|
||||
}
|
||||
|
||||
// If both succeed, results should be the same
|
||||
if fastErr == nil && stdErr == nil && fast != std {
|
||||
t.Errorf("result mismatch for %q: FastAbs=%s, filepath.Abs=%s", path, fast, std)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastAbsErrorHandling tests error handling when working directory is unavailable
|
||||
func TestFastAbsErrorHandling(t *testing.T) {
|
||||
// This tests the cached wderr behavior
|
||||
if wderr != nil {
|
||||
// Test that FastAbs properly returns the cached error for relative paths
|
||||
_, err := FastAbs("relative/path")
|
||||
if err == nil {
|
||||
t.Error("expected error for relative path when working directory is unavailable")
|
||||
}
|
||||
if err != wderr {
|
||||
t.Errorf("expected cached wderr, got different error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFastAbs benchmarks FastAbs
|
||||
func BenchmarkFastAbs(b *testing.B) {
|
||||
paths := []string{
|
||||
"relative/path",
|
||||
"/absolute/path",
|
||||
".",
|
||||
"..",
|
||||
"./foo/bar",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
FastAbs(paths[i%len(paths)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFastAbsVsStdLib compares performance of FastAbs vs filepath.Abs
|
||||
func BenchmarkFastAbsVsStdLib(b *testing.B) {
|
||||
path := "relative/path/to/file"
|
||||
|
||||
b.Run("FastAbs", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
FastAbs(path)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("filepath.Abs", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
filepath.Abs(path)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,351 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock filesystem implementation for testing
|
||||
type mockFileSystem struct {
|
||||
name string
|
||||
files map[string]string
|
||||
}
|
||||
|
||||
func (m *mockFileSystem) Open(name string) (fs.File, error) {
|
||||
if content, exists := m.files[name]; exists {
|
||||
return &mockFile{name: name, content: content}, nil
|
||||
}
|
||||
return nil, fs.ErrNotExist
|
||||
}
|
||||
|
||||
type mockFile struct {
|
||||
name string
|
||||
content string
|
||||
pos int
|
||||
}
|
||||
|
||||
func (m *mockFile) Stat() (fs.FileInfo, error) {
|
||||
return &mockFileInfo{name: m.name, size: int64(len(m.content))}, nil
|
||||
}
|
||||
|
||||
func (m *mockFile) Read(b []byte) (int, error) {
|
||||
if m.pos >= len(m.content) {
|
||||
return 0, fs.ErrClosed
|
||||
}
|
||||
n := copy(b, m.content[m.pos:])
|
||||
m.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *mockFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockFileInfo struct {
|
||||
name string
|
||||
size int64
|
||||
}
|
||||
|
||||
func (m *mockFileInfo) Name() string { return m.name }
|
||||
func (m *mockFileInfo) Size() int64 { return m.size }
|
||||
func (m *mockFileInfo) Mode() fs.FileMode { return 0o644 }
|
||||
func (m *mockFileInfo) ModTime() time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
func (m *mockFileInfo) IsDir() bool { return false }
|
||||
func (m *mockFileInfo) Sys() any { return nil }
|
||||
|
||||
// Mock FileSystems implementation for testing
|
||||
type mockFileSystems struct {
|
||||
mu sync.RWMutex
|
||||
filesystems map[string]fs.FS
|
||||
defaultFS fs.FS
|
||||
}
|
||||
|
||||
func newMockFileSystems() *mockFileSystems {
|
||||
return &mockFileSystems{
|
||||
filesystems: make(map[string]fs.FS),
|
||||
defaultFS: &mockFileSystem{name: "default", files: map[string]string{"default.txt": "default content"}},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Register(k string, v fs.FS) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.filesystems[k] = v
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Unregister(k string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.filesystems, k)
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Get(k string) (fs.FS, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
v, ok := m.filesystems[k]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Default() fs.FS {
|
||||
return m.defaultFS
|
||||
}
|
||||
|
||||
func TestFileSystems_Register_Get(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
mockFS := &mockFileSystem{
|
||||
name: "test",
|
||||
files: map[string]string{"test.txt": "test content"},
|
||||
}
|
||||
|
||||
// Register filesystem
|
||||
fsys.Register("test", mockFS)
|
||||
|
||||
// Retrieve filesystem
|
||||
retrieved, exists := fsys.Get("test")
|
||||
if !exists {
|
||||
t.Error("Expected filesystem to exist after registration")
|
||||
}
|
||||
if retrieved != mockFS {
|
||||
t.Error("Retrieved filesystem is not the same as registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Unregister(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
mockFS := &mockFileSystem{name: "test"}
|
||||
|
||||
// Register then unregister
|
||||
fsys.Register("test", mockFS)
|
||||
fsys.Unregister("test")
|
||||
|
||||
// Should not exist after unregistration
|
||||
_, exists := fsys.Get("test")
|
||||
if exists {
|
||||
t.Error("Filesystem should not exist after unregistration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Default(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
defaultFS := fsys.Default()
|
||||
if defaultFS == nil {
|
||||
t.Error("Default filesystem should not be nil")
|
||||
}
|
||||
|
||||
// Test that default filesystem works
|
||||
file, err := defaultFS.Open("default.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open default file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data := make([]byte, 100)
|
||||
n, err := file.Read(data)
|
||||
if err != nil && err != fs.ErrClosed {
|
||||
t.Fatalf("Failed to read default file: %v", err)
|
||||
}
|
||||
|
||||
content := string(data[:n])
|
||||
if content != "default content" {
|
||||
t.Errorf("Expected 'default content', got '%s'", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Concurrent_Access(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
const numGoroutines = 50
|
||||
const numOperations = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent register/unregister/get operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
key := fmt.Sprintf("fs-%d", id)
|
||||
mockFS := &mockFileSystem{
|
||||
name: key,
|
||||
files: map[string]string{key + ".txt": "content"},
|
||||
}
|
||||
|
||||
for j := 0; j < numOperations; j++ {
|
||||
// Register
|
||||
fsys.Register(key, mockFS)
|
||||
|
||||
// Get
|
||||
retrieved, exists := fsys.Get(key)
|
||||
if !exists {
|
||||
t.Errorf("Filesystem %s should exist", key)
|
||||
continue
|
||||
}
|
||||
if retrieved != mockFS {
|
||||
t.Errorf("Retrieved filesystem for %s is not correct", key)
|
||||
}
|
||||
|
||||
// Test file access
|
||||
file, err := retrieved.Open(key + ".txt")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to open file in %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// Unregister
|
||||
fsys.Unregister(key)
|
||||
|
||||
// Should not exist after unregister
|
||||
_, stillExists := fsys.Get(key)
|
||||
if stillExists {
|
||||
t.Errorf("Filesystem %s should not exist after unregister", key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestFileSystems_Get_NonExistent(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
_, exists := fsys.Get("non-existent")
|
||||
if exists {
|
||||
t.Error("Non-existent filesystem should not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Register_Overwrite(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
key := "overwrite-test"
|
||||
|
||||
// Register first filesystem
|
||||
fs1 := &mockFileSystem{name: "fs1"}
|
||||
fsys.Register(key, fs1)
|
||||
|
||||
// Register second filesystem with same key (should overwrite)
|
||||
fs2 := &mockFileSystem{name: "fs2"}
|
||||
fsys.Register(key, fs2)
|
||||
|
||||
// Should get the second filesystem
|
||||
retrieved, exists := fsys.Get(key)
|
||||
if !exists {
|
||||
t.Error("Filesystem should exist")
|
||||
}
|
||||
if retrieved != fs2 {
|
||||
t.Error("Should get the overwritten filesystem")
|
||||
}
|
||||
if retrieved == fs1 {
|
||||
t.Error("Should not get the original filesystem")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Concurrent_RegisterUnregister_SameKey(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
key := "concurrent-key"
|
||||
|
||||
const numGoroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Half the goroutines register, half unregister
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
if i%2 == 0 {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
mockFS := &mockFileSystem{name: fmt.Sprintf("fs-%d", id)}
|
||||
fsys.Register(key, mockFS)
|
||||
}(i)
|
||||
} else {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
fsys.Unregister(key)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// The final state is unpredictable due to race conditions,
|
||||
// but the operations should not panic or cause corruption
|
||||
// Test passes if we reach here without issues
|
||||
}
|
||||
|
||||
func TestFileSystems_StressTest(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping stress test in short mode")
|
||||
}
|
||||
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
const numGoroutines = 100
|
||||
const duration = 100 * time.Millisecond
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Start timer
|
||||
go func() {
|
||||
time.Sleep(duration)
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
// Stress test with continuous operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
key := fmt.Sprintf("stress-fs-%d", id%10) // Use limited set of keys
|
||||
mockFS := &mockFileSystem{
|
||||
name: key,
|
||||
files: map[string]string{key + ".txt": "stress content"},
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
// Rapid register/get/unregister cycles
|
||||
fsys.Register(key, mockFS)
|
||||
|
||||
if retrieved, exists := fsys.Get(key); exists {
|
||||
// Try to use the filesystem
|
||||
if file, err := retrieved.Open(key + ".txt"); err == nil {
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
fsys.Unregister(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Test passes if we reach here without panics or deadlocks
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/caddyserver/caddy/v2
|
||||
|
||||
go 1.25.0
|
||||
go 1.25.1
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
@@ -32,27 +32,27 @@ require (
|
||||
github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc
|
||||
go.opentelemetry.io/contrib/bridges/prometheus v0.68.0
|
||||
go.opentelemetry.io/contrib/exporters/autoexport v0.65.0
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.65.0
|
||||
go.opentelemetry.io/otel v1.43.0
|
||||
go.opentelemetry.io/otel/sdk v1.43.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0
|
||||
go.step.sm/crypto v0.77.1
|
||||
go.step.sm/crypto v0.81.0
|
||||
go.uber.org/automaxprocs v1.6.0
|
||||
go.uber.org/zap v1.27.1
|
||||
go.uber.org/zap/exp v0.3.0
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/crypto v0.51.0
|
||||
golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541
|
||||
golang.org/x/net v0.53.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/term v0.43.0
|
||||
golang.org/x/time v0.15.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cel.dev/expr v0.25.1 // indirect
|
||||
cloud.google.com/go/auth v0.18.2 // indirect
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
dario.cat/mergo v1.0.2 // indirect
|
||||
@@ -63,14 +63,14 @@ require (
|
||||
github.com/coreos/go-oidc/v3 v3.17.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
|
||||
github.com/go-jose/go-jose/v3 v3.0.4 // indirect
|
||||
github.com/go-jose/go-jose/v3 v3.0.5 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 // indirect
|
||||
github.com/google/certificate-transparency-go v1.1.8-0.20240110162603-74a5dd331745 // indirect
|
||||
github.com/google/go-tpm v0.9.8 // indirect
|
||||
github.com/google/go-tspi v0.3.0 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.18.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
|
||||
github.com/jackc/pgx/v5 v5.9.2 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
@@ -109,9 +109,9 @@ require (
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
google.golang.org/api v0.271.0 // indirect
|
||||
google.golang.org/api v0.277.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect
|
||||
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect
|
||||
)
|
||||
|
||||
@@ -169,10 +169,10 @@ require (
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/mod v0.35.0 // indirect
|
||||
golang.org/x/sys v0.43.0
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
golang.org/x/sys v0.44.0
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
golang.org/x/tools v0.44.0 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/grpc v1.81.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
howett.net/plist v1.0.0 // indirect
|
||||
)
|
||||
|
||||
@@ -2,18 +2,18 @@ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
|
||||
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
|
||||
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc=
|
||||
cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU=
|
||||
cloud.google.com/go/kms v1.26.0 h1:cK9mN2cf+9V63D3H1f6koxTatWy39aTI/hCjz1I+adU=
|
||||
cloud.google.com/go/kms v1.26.0/go.mod h1:pHKOdFJm63hxBsiPkYtowZPltu9dW0MWvBa6IA4HM58=
|
||||
cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8=
|
||||
cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk=
|
||||
cloud.google.com/go/iam v1.7.0 h1:JD3zh0C6LHl16aCn5Akff0+GELdp1+4hmh6ndoFLl8U=
|
||||
cloud.google.com/go/iam v1.7.0/go.mod h1:tetWZW1PD/m6vcuY2Zj/aU0eCHNPuxedbnbRTyKXvdY=
|
||||
cloud.google.com/go/kms v1.31.0 h1:LS8N92OxFDgOLg5NCo3OmbvjtQAIVT5gUHVLKIDHaFE=
|
||||
cloud.google.com/go/kms v1.31.0/go.mod h1:YIyXZym11R5uovJJt4oN5eUL3oPmirF3yKeIh6QAf4U=
|
||||
cloud.google.com/go/longrunning v0.9.0 h1:0EzbDEGsAvOZNbqXopgniY0w0a1phvu5IdUFq8grmqY=
|
||||
cloud.google.com/go/longrunning v0.9.0/go.mod h1:pkTz846W7bF4o2SzdWJ40Hu0Re+UoNT6Q5t+igIcb8E=
|
||||
code.pfad.fr/check v1.1.0 h1:GWvjdzhSEgHvEHe2uJujDcpmZoySKuHQNrZMfzfO0bE=
|
||||
code.pfad.fr/check v1.1.0/go.mod h1:NiUH13DtYsb7xp5wll0U4SXx7KhXQVCtRgdC96IPfoM=
|
||||
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||
@@ -53,36 +53,36 @@ github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmO
|
||||
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
|
||||
github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b h1:uUXgbcPDK3KpW29o4iy7GtuappbWT0l5NaMo9H9pJDw=
|
||||
github.com/aryann/difflib v0.0.0-20210328193216-ff5ff6dc229b/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.50.3 h1:s/zDSG/a/Su9aX+v0Ld9cimUCdkr5FWPmBV8owaEbZY=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.50.3/go.mod h1:/iSgiUor15ZuxFGQSTf3lA2FmKxFsQoc2tADOarQBSw=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.51.1 h1:zuSf4olLKZW8cF/W9Y5wvGT+/0raY/3kVp49KsGs0QY=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.51.1/go.mod h1:Y0+uxvxz6ib4KktRdK0V4X45Vcs/JyYoz8H71pO8xeI=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
||||
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
||||
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/caddyserver/certmagic v0.25.3 h1:mGf5ba8F7xA4c5jfDZZbK2buY1VEkbnwpMDixaju94A=
|
||||
@@ -149,8 +149,8 @@ github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sa
|
||||
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY=
|
||||
github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
|
||||
github.com/go-jose/go-jose/v3 v3.0.5 h1:BLLJWbC4nMZOfuPVxoZIxeYsn6Nl2r1fITaJ78UQlVQ=
|
||||
github.com/go-jose/go-jose/v3 v3.0.5/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -179,18 +179,18 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo=
|
||||
github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/go-tpm-tools v0.4.7 h1:J3ycC8umYxM9A4eF73EofRZu4BxY0jjQnUnkhIBbvws=
|
||||
github.com/google/go-tpm-tools v0.4.7/go.mod h1:gSyXTZHe3fgbzb6WEGd90QucmsnT1SRdlye82gH8QjQ=
|
||||
github.com/google/go-tpm-tools v0.4.8 h1:V4oIYyAD3BykOycwYQzO29WefDouQMTsYZqmG3HxOfM=
|
||||
github.com/google/go-tpm-tools v0.4.8/go.mod h1:4DfiOtiS1KppJjwf1+tqtW4K3PrCJjAAqFKj/TYTJKg=
|
||||
github.com/google/go-tspi v0.3.0 h1:ADtq8RKfP+jrTyIWIZDIYcKOMecRqNJFOew2IT0Inus=
|
||||
github.com/google/go-tspi v0.3.0/go.mod h1:xfMGI3G0PhxCdNVcYr1C4C+EizojDg/TXuX5by8CiHI=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.18.0 h1:jxP5Uuo3bxm3M6gGtV94P4lliVetoCB4Wk2x8QA86LI=
|
||||
github.com/googleapis/gax-go/v2 v2.18.0/go.mod h1:uSzZN4a356eRG985CzJ3WfbFSpqkLTjsnhWGJR6EwrE=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 h1:xolVQTEXusUcAA5UgtyRLjelpFFHWlPQ4XfWGc7MBas=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
@@ -377,10 +377,10 @@ go.opentelemetry.io/contrib/bridges/prometheus v0.68.0 h1:w3zlHYETbDwXyWHZlyyR58
|
||||
go.opentelemetry.io/contrib/bridges/prometheus v0.68.0/go.mod h1:GR/mClR2nn7vE8RLwxKjoBNg+QtgdDhRzxVa93koy5o=
|
||||
go.opentelemetry.io/contrib/exporters/autoexport v0.65.0 h1:2gApdml7SznX9szEKFjKjM4qGcGSvAybYLBY319XG3g=
|
||||
go.opentelemetry.io/contrib/exporters/autoexport v0.65.0/go.mod h1:0QqAGlbHXhmPYACG3n5hNzO5DnEqqtg4VcK5pr22RI0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.65.0 h1:kTaCycF9Xkm8VBBvH0rJ4wFeRjtIV55Erk3uuVsIs5s=
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.65.0/go.mod h1:rooPzAbXfxMX9fsPJjmOBg2SN4RhFEV8D7cfGK+N3tE=
|
||||
go.opentelemetry.io/contrib/propagators/aws v1.43.0 h1:EwnsB3cXRLAh7/Nr/9rMuGw73nfb3z6uAvVDjRrbeUg=
|
||||
@@ -431,8 +431,8 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.step.sm/crypto v0.77.1 h1:4EEqfKdv0egQ1lqz2RhnU8Jv6QgXZfrgoxWMqJF9aDs=
|
||||
go.step.sm/crypto v0.77.1/go.mod h1:U/SsmEm80mNnfD5WIkbhuW/B1eFp3fgFvdXyDLpU1AQ=
|
||||
go.step.sm/crypto v0.81.0 h1:e+ouzpNt3Xm4dp7HGXhgYB5y4iFik3vh3phHKWmvugU=
|
||||
go.step.sm/crypto v0.81.0/go.mod h1:fsTizqQeASjTXnbv9O00XtRlIuXRkCdoRiJNyXGQujc=
|
||||
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
@@ -456,8 +456,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||
golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541 h1:FmKxj9ocLKn45jiR2jQMwCVhDvaK7fKQFzfuT9GvyK8=
|
||||
golang.org/x/crypto/x509roots/fallback v0.0.0-20260213171211-a408498e5541/go.mod h1:+UoQFNBq2p2wO+Q6ddVtYc25GZ6VNdOMyyrd4nrqrKs=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
@@ -506,8 +506,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@@ -517,8 +517,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
@@ -528,8 +528,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -543,16 +543,16 @@ golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY=
|
||||
google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q=
|
||||
google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d h1:vsOm753cOAMkt76efriTCDKjpCbK18XGHMJHo0JUKhc=
|
||||
google.golang.org/genproto v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:0oz9d7g9QLSdv9/lgbIjowW1JoxMbxmBVNe8i6tORJI=
|
||||
google.golang.org/api v0.277.0 h1:HJfyJUiNeBBUMai7ez8u14wkp/gH/I4wpGbbO9o+cSk=
|
||||
google.golang.org/api v0.277.0/go.mod h1:B9TqLBwJqVjp1mtt7WeoQwWRwvu/400y5lETOql+giQ=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.0 h1:W3G9N3KQf3BU+YuCtGKJk0CmxQNbAISICD/9AORxLIw=
|
||||
google.golang.org/grpc v1.81.0/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 h1:F29+wU6Ee6qgu9TddPgooOdaqsxTMunOoj8KA5yuS5A=
|
||||
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1/go.mod h1:5KF+wpkbTSbGcR9zteSqZV6fqFOWBl4Yde8En8MryZA=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package filesystems
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
)
|
||||
|
||||
func TestFileSystemMapDefaultKey(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
|
||||
// Empty key should map to default
|
||||
if m.key("") != DefaultFileSystemKey {
|
||||
t.Errorf("empty key should map to %q, got %q", DefaultFileSystemKey, m.key(""))
|
||||
}
|
||||
|
||||
// Non-empty key should be returned as-is
|
||||
if m.key("custom") != "custom" {
|
||||
t.Errorf("non-empty key should be returned as-is, got %q", m.key("custom"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapRegisterAndGet(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
testFS := fstest.MapFS{
|
||||
"hello.txt": &fstest.MapFile{Data: []byte("hello")},
|
||||
}
|
||||
|
||||
m.Register("test", testFS)
|
||||
|
||||
got, ok := m.Get("test")
|
||||
if !ok {
|
||||
t.Fatal("expected to find registered filesystem")
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil filesystem")
|
||||
}
|
||||
|
||||
// Verify the filesystem works
|
||||
f, err := got.Open("hello.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Open() error = %v", err)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
func TestFileSystemMapGetNonExistent(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
|
||||
_, ok := m.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected Get to return false for nonexistent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapDefault(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
|
||||
d := m.Default()
|
||||
if d == nil {
|
||||
t.Fatal("Default() should never return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapGetDefaultLazyInit(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
|
||||
// Getting the default key before any registration should
|
||||
// auto-initialize to DefaultFileSystem
|
||||
got, ok := m.Get(DefaultFileSystemKey)
|
||||
if !ok {
|
||||
t.Fatal("expected default filesystem to be auto-initialized")
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil default filesystem")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapUnregister(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
testFS := fstest.MapFS{}
|
||||
|
||||
m.Register("test", testFS)
|
||||
m.Unregister("test")
|
||||
|
||||
_, ok := m.Get("test")
|
||||
if ok {
|
||||
t.Error("expected filesystem to be unregistered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapUnregisterDefault(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
customFS := fstest.MapFS{}
|
||||
|
||||
// Override default
|
||||
m.Register("", customFS)
|
||||
// Unregister default should reset to OsFS, not delete
|
||||
m.Unregister("")
|
||||
|
||||
d := m.Default()
|
||||
if d == nil {
|
||||
t.Fatal("unregistering default should reset it, not delete it")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapRegisterNil(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
testFS := fstest.MapFS{}
|
||||
|
||||
// Register then register nil (should unregister)
|
||||
m.Register("test", testFS)
|
||||
m.Register("test", nil)
|
||||
|
||||
_, ok := m.Get("test")
|
||||
if ok {
|
||||
t.Error("registering nil should unregister the filesystem")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapEmptyKeyIsDefault(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
testFS := fstest.MapFS{
|
||||
"test.txt": &fstest.MapFile{Data: []byte("test")},
|
||||
}
|
||||
|
||||
// Register with empty key should register as default
|
||||
m.Register("", testFS)
|
||||
|
||||
got, ok := m.Get("")
|
||||
if !ok {
|
||||
t.Fatal("expected to find filesystem registered with empty key")
|
||||
}
|
||||
|
||||
// Should also be accessible via default key
|
||||
got2, ok := m.Get(DefaultFileSystemKey)
|
||||
if !ok {
|
||||
t.Fatal("expected to find filesystem via default key")
|
||||
}
|
||||
|
||||
// Both should work
|
||||
if got == nil || got2 == nil {
|
||||
t.Fatal("expected non-nil filesystems")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystemMapGetTrimsWhitespace(t *testing.T) {
|
||||
m := &FileSystemMap{}
|
||||
testFS := fstest.MapFS{}
|
||||
|
||||
m.Register("test", testFS)
|
||||
|
||||
// Get with whitespace-padded key should match
|
||||
got, ok := m.Get("test ")
|
||||
if !ok {
|
||||
t.Fatal("expected Get to trim whitespace from key")
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil filesystem")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOsFSInterfaces(t *testing.T) {
|
||||
var osFS OsFS
|
||||
|
||||
// Verify interface compliance at compile time (already done with var _ checks)
|
||||
// but test that the methods exist and are callable
|
||||
var _ fs.FS = osFS
|
||||
var _ fs.StatFS = osFS
|
||||
var _ fs.GlobFS = osFS
|
||||
var _ fs.ReadDirFS = osFS
|
||||
var _ fs.ReadFileFS = osFS
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"go.uber.org/zap/zaptest/observer"
|
||||
)
|
||||
|
||||
func TestLogBufferCoreEnabled(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
|
||||
if !core.Enabled(zapcore.InfoLevel) {
|
||||
t.Error("expected InfoLevel to be enabled")
|
||||
}
|
||||
if !core.Enabled(zapcore.ErrorLevel) {
|
||||
t.Error("expected ErrorLevel to be enabled")
|
||||
}
|
||||
if core.Enabled(zapcore.DebugLevel) {
|
||||
t.Error("expected DebugLevel to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBufferCoreWriteAndFlush(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
|
||||
// Write entries
|
||||
entry1 := zapcore.Entry{Level: zapcore.InfoLevel, Message: "message1"}
|
||||
entry2 := zapcore.Entry{Level: zapcore.WarnLevel, Message: "message2"}
|
||||
|
||||
if err := core.Write(entry1, []zapcore.Field{zap.String("key1", "val1")}); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
if err := core.Write(entry2, []zapcore.Field{zap.String("key2", "val2")}); err != nil {
|
||||
t.Fatalf("Write() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify entries are buffered
|
||||
if len(core.entries) != 2 {
|
||||
t.Errorf("expected 2 entries, got %d", len(core.entries))
|
||||
}
|
||||
if len(core.fields) != 2 {
|
||||
t.Errorf("expected 2 field sets, got %d", len(core.fields))
|
||||
}
|
||||
|
||||
// Set up an observed logger to capture flushed entries
|
||||
observedCore, logs := observer.New(zapcore.InfoLevel)
|
||||
logger := zap.New(observedCore)
|
||||
|
||||
core.FlushTo(logger)
|
||||
|
||||
// Verify entries were flushed
|
||||
if logs.Len() != 2 {
|
||||
t.Errorf("expected 2 flushed log entries, got %d", logs.Len())
|
||||
}
|
||||
|
||||
// Verify buffer is cleared after flush
|
||||
if len(core.entries) != 0 {
|
||||
t.Errorf("expected entries to be cleared after flush, got %d", len(core.entries))
|
||||
}
|
||||
if len(core.fields) != 0 {
|
||||
t.Errorf("expected fields to be cleared after flush, got %d", len(core.fields))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBufferCoreSync(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
if err := core.Sync(); err != nil {
|
||||
t.Errorf("Sync() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBufferCoreWith(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
|
||||
// With() currently returns the same core (known limitation)
|
||||
result := core.With([]zapcore.Field{zap.String("test", "val")})
|
||||
if result != core {
|
||||
t.Error("With() should return the same core instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBufferCoreCheck(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
|
||||
// Check for enabled level should add core
|
||||
entry := zapcore.Entry{Level: zapcore.InfoLevel, Message: "test"}
|
||||
ce := &zapcore.CheckedEntry{}
|
||||
result := core.Check(entry, ce)
|
||||
if result == nil {
|
||||
t.Error("Check() should return non-nil for enabled level")
|
||||
}
|
||||
|
||||
// Check for disabled level should not add core
|
||||
debugEntry := zapcore.Entry{Level: zapcore.DebugLevel, Message: "test"}
|
||||
ce2 := &zapcore.CheckedEntry{}
|
||||
result2 := core.Check(debugEntry, ce2)
|
||||
// The ce2 should be returned unchanged (no core added)
|
||||
if result2 != ce2 {
|
||||
t.Error("Check() should return unchanged CheckedEntry for disabled level")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBufferCoreEmptyFlush(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
|
||||
// Flushing with no entries should not panic
|
||||
observedCore, logs := observer.New(zapcore.InfoLevel)
|
||||
logger := zap.New(observedCore)
|
||||
|
||||
core.FlushTo(logger)
|
||||
|
||||
if logs.Len() != 0 {
|
||||
t.Errorf("expected 0 flushed entries for empty buffer, got %d", logs.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBufferCoreConcurrentWrites(t *testing.T) {
|
||||
core := NewLogBufferCore(zapcore.InfoLevel)
|
||||
|
||||
done := make(chan struct{})
|
||||
const numWriters = 10
|
||||
const numWrites = 100
|
||||
|
||||
for i := 0; i < numWriters; i++ {
|
||||
go func() {
|
||||
defer func() { done <- struct{}{} }()
|
||||
for j := 0; j < numWrites; j++ {
|
||||
entry := zapcore.Entry{Level: zapcore.InfoLevel, Message: "concurrent"}
|
||||
_ = core.Write(entry, nil)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < numWriters; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
core.mu.Lock()
|
||||
count := len(core.entries)
|
||||
core.mu.Unlock()
|
||||
|
||||
if count != numWriters*numWrites {
|
||||
t.Errorf("expected %d entries, got %d", numWriters*numWrites, count)
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,37 @@ func TestSanitizeMethod(t *testing.T) {
|
||||
{method: "trace", expected: "TRACE"},
|
||||
{method: "UNKNOWN", expected: "OTHER"},
|
||||
{method: strings.Repeat("ohno", 9999), expected: "OTHER"},
|
||||
|
||||
// Test all standard HTTP methods in uppercase
|
||||
{method: "GET", expected: "GET"},
|
||||
{method: "HEAD", expected: "HEAD"},
|
||||
{method: "POST", expected: "POST"},
|
||||
{method: "PUT", expected: "PUT"},
|
||||
{method: "DELETE", expected: "DELETE"},
|
||||
{method: "CONNECT", expected: "CONNECT"},
|
||||
{method: "OPTIONS", expected: "OPTIONS"},
|
||||
{method: "TRACE", expected: "TRACE"},
|
||||
{method: "PATCH", expected: "PATCH"},
|
||||
|
||||
// Test all standard HTTP methods in lowercase
|
||||
{method: "get", expected: "GET"},
|
||||
{method: "head", expected: "HEAD"},
|
||||
{method: "post", expected: "POST"},
|
||||
{method: "put", expected: "PUT"},
|
||||
{method: "delete", expected: "DELETE"},
|
||||
{method: "connect", expected: "CONNECT"},
|
||||
{method: "options", expected: "OPTIONS"},
|
||||
{method: "trace", expected: "TRACE"},
|
||||
{method: "patch", expected: "PATCH"},
|
||||
|
||||
// Test mixed case and non-standard methods
|
||||
{method: "Get", expected: "OTHER"},
|
||||
{method: "gEt", expected: "OTHER"},
|
||||
{method: "UNKNOWN", expected: "OTHER"},
|
||||
{method: "PROPFIND", expected: "OTHER"},
|
||||
{method: "MKCOL", expected: "OTHER"},
|
||||
{method: "", expected: "OTHER"},
|
||||
{method: " ", expected: "OTHER"},
|
||||
}
|
||||
|
||||
for _, d := range tests {
|
||||
@@ -26,3 +57,79 @@ func TestSanitizeMethod(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero returns 200",
|
||||
code: 0,
|
||||
expected: "200",
|
||||
},
|
||||
{
|
||||
name: "200 returns 200",
|
||||
code: 200,
|
||||
expected: "200",
|
||||
},
|
||||
{
|
||||
name: "404 returns 404",
|
||||
code: 404,
|
||||
expected: "404",
|
||||
},
|
||||
{
|
||||
name: "500 returns 500",
|
||||
code: 500,
|
||||
expected: "500",
|
||||
},
|
||||
{
|
||||
name: "301 returns 301",
|
||||
code: 301,
|
||||
expected: "301",
|
||||
},
|
||||
{
|
||||
name: "418 teapot returns 418",
|
||||
code: 418,
|
||||
expected: "418",
|
||||
},
|
||||
{
|
||||
name: "999 custom code",
|
||||
code: 999,
|
||||
expected: "999",
|
||||
},
|
||||
{
|
||||
name: "negative code",
|
||||
code: -1,
|
||||
expected: "-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeCode(tt.code)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeCode(%d) = %s; want %s", tt.code, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSanitizeCode benchmarks the SanitizeCode function
|
||||
func BenchmarkSanitizeCode(b *testing.B) {
|
||||
codes := []int{0, 200, 404, 500, 301, 418}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SanitizeCode(codes[i%len(codes)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSanitizeMethod benchmarks the SanitizeMethod function
|
||||
func BenchmarkSanitizeMethod(b *testing.B) {
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "UNKNOWN"}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SanitizeMethod(methods[i%len(methods)])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPrivateRangesCIDR(t *testing.T) {
|
||||
ranges := PrivateRangesCIDR()
|
||||
|
||||
// Should include standard private IP ranges
|
||||
expected := map[string]bool{
|
||||
"192.168.0.0/16": false,
|
||||
"172.16.0.0/12": false,
|
||||
"10.0.0.0/8": false,
|
||||
"127.0.0.1/8": false,
|
||||
"fd00::/8": false,
|
||||
"::1": false,
|
||||
}
|
||||
|
||||
for _, r := range ranges {
|
||||
if _, ok := expected[r]; ok {
|
||||
expected[r] = true
|
||||
}
|
||||
}
|
||||
|
||||
for cidr, found := range expected {
|
||||
if !found {
|
||||
t.Errorf("expected private range %q not found in PrivateRangesCIDR()", cidr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ranges) < 6 {
|
||||
t.Errorf("expected at least 6 private ranges, got %d", len(ranges))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxSizeSubjectsListForLog(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subjects map[string]struct{}
|
||||
maxToDisplay int
|
||||
wantLen int
|
||||
wantSuffix bool // whether "(and N more...)" is expected
|
||||
}{
|
||||
{
|
||||
name: "empty map",
|
||||
subjects: map[string]struct{}{},
|
||||
maxToDisplay: 5,
|
||||
wantLen: 0,
|
||||
wantSuffix: false,
|
||||
},
|
||||
{
|
||||
name: "fewer than max",
|
||||
subjects: map[string]struct{}{
|
||||
"example.com": {},
|
||||
"example.org": {},
|
||||
},
|
||||
maxToDisplay: 5,
|
||||
wantLen: 2,
|
||||
wantSuffix: false,
|
||||
},
|
||||
{
|
||||
name: "equal to max",
|
||||
subjects: map[string]struct{}{
|
||||
"a.com": {},
|
||||
"b.com": {},
|
||||
"c.com": {},
|
||||
},
|
||||
maxToDisplay: 3,
|
||||
wantLen: 3,
|
||||
wantSuffix: false,
|
||||
},
|
||||
{
|
||||
name: "more than max",
|
||||
subjects: map[string]struct{}{
|
||||
"a.com": {},
|
||||
"b.com": {},
|
||||
"c.com": {},
|
||||
"d.com": {},
|
||||
"e.com": {},
|
||||
},
|
||||
maxToDisplay: 2,
|
||||
wantLen: 3, // 2 domains + suffix
|
||||
wantSuffix: true,
|
||||
},
|
||||
{
|
||||
name: "max is zero",
|
||||
subjects: map[string]struct{}{
|
||||
"a.com": {},
|
||||
"b.com": {},
|
||||
},
|
||||
maxToDisplay: 0,
|
||||
// BUG: When maxToDisplay is 0, code still appends one domain
|
||||
// because append happens before the break check in the loop.
|
||||
// Expected behavior: 1 item (just suffix). Actual: 2 items
|
||||
// (1 leaked domain + suffix).
|
||||
wantLen: 2,
|
||||
wantSuffix: true,
|
||||
},
|
||||
{
|
||||
name: "single subject with max 1",
|
||||
subjects: map[string]struct{}{
|
||||
"example.com": {},
|
||||
},
|
||||
maxToDisplay: 1,
|
||||
wantLen: 1,
|
||||
wantSuffix: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := MaxSizeSubjectsListForLog(tt.subjects, tt.maxToDisplay)
|
||||
if len(result) != tt.wantLen {
|
||||
t.Errorf("MaxSizeSubjectsListForLog() returned %d items, want %d; got: %v", len(result), tt.wantLen, result)
|
||||
}
|
||||
if tt.wantSuffix {
|
||||
last := result[len(result)-1]
|
||||
if len(last) < 4 || last[:4] != "(and" {
|
||||
t.Errorf("expected suffix '(and N more...)' but got %q", last)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitUnixSocketPermissionsBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantPath string
|
||||
wantFileMode fs.FileMode
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no permission bits defaults to 0200",
|
||||
input: "/run/caddy.sock",
|
||||
wantPath: "/run/caddy.sock",
|
||||
wantFileMode: 0o200,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid permission 0222",
|
||||
input: "/run/caddy.sock|0222",
|
||||
wantPath: "/run/caddy.sock",
|
||||
wantFileMode: 0o222,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid permission 0200",
|
||||
input: "/run/caddy.sock|0200",
|
||||
wantPath: "/run/caddy.sock",
|
||||
wantFileMode: 0o200,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid permission 0777",
|
||||
input: "/run/caddy.sock|0777",
|
||||
wantPath: "/run/caddy.sock",
|
||||
wantFileMode: 0o777,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid permission 0755",
|
||||
input: "/run/caddy.sock|0755",
|
||||
wantPath: "/run/caddy.sock",
|
||||
wantFileMode: 0o755,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid permission 0666",
|
||||
input: "/tmp/test.sock|0666",
|
||||
wantPath: "/tmp/test.sock",
|
||||
wantFileMode: 0o666,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing owner write permission 0444",
|
||||
input: "/run/caddy.sock|0444",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing owner write permission 0044",
|
||||
input: "/run/caddy.sock|0044",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing owner write permission 0100",
|
||||
input: "/run/caddy.sock|0100",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing owner write permission 0500",
|
||||
input: "/run/caddy.sock|0500",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid octal digits",
|
||||
input: "/run/caddy.sock|09ab",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid non-numeric permission",
|
||||
input: "/run/caddy.sock|rwxrwxrwx",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty permission string",
|
||||
input: "/run/caddy.sock|",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "multiple pipes only splits on first",
|
||||
input: "/run/caddy|sock|0222",
|
||||
wantPath: "/run/caddy",
|
||||
wantFileMode: 0, // "sock|0222" is not valid octal
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty path with valid permission",
|
||||
input: "|0222",
|
||||
wantPath: "",
|
||||
wantFileMode: 0o222,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path only with no pipe",
|
||||
input: "/var/run/my-app.sock",
|
||||
wantPath: "/var/run/my-app.sock",
|
||||
wantFileMode: 0o200,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "permission 0300 has write bit",
|
||||
input: "/run/caddy.sock|0300",
|
||||
wantPath: "/run/caddy.sock",
|
||||
wantFileMode: 0o300,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "permission 0422 missing owner write",
|
||||
input: "/run/caddy.sock|0422",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotPath, gotMode, err := SplitUnixSocketPermissionsBits(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SplitUnixSocketPermissionsBits(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if gotPath != tt.wantPath {
|
||||
t.Errorf("SplitUnixSocketPermissionsBits(%q) path = %q, want %q", tt.input, gotPath, tt.wantPath)
|
||||
}
|
||||
if gotMode != tt.wantFileMode {
|
||||
t.Errorf("SplitUnixSocketPermissionsBits(%q) mode = %04o, want %04o", tt.input, gotMode, tt.wantFileMode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -361,7 +361,7 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui
|
||||
if end < start {
|
||||
return NetworkAddress{}, fmt.Errorf("end port must not be less than start port")
|
||||
}
|
||||
if (end - start) > maxPortSpan {
|
||||
if (end-start)+1 > maxPortSpan {
|
||||
return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
|
||||
}
|
||||
}
|
||||
|
||||
+394
@@ -0,0 +1,394 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func TestGlobalMetrics_ConfigSuccess(t *testing.T) {
|
||||
// Test setting config success metric
|
||||
originalValue := getMetricValue(globalMetrics.configSuccess)
|
||||
|
||||
// Set to success
|
||||
globalMetrics.configSuccess.Set(1)
|
||||
newValue := getMetricValue(globalMetrics.configSuccess)
|
||||
|
||||
if newValue != 1 {
|
||||
t.Errorf("Expected config success metric to be 1, got %f", newValue)
|
||||
}
|
||||
|
||||
// Set to failure
|
||||
globalMetrics.configSuccess.Set(0)
|
||||
failureValue := getMetricValue(globalMetrics.configSuccess)
|
||||
|
||||
if failureValue != 0 {
|
||||
t.Errorf("Expected config success metric to be 0, got %f", failureValue)
|
||||
}
|
||||
|
||||
// Restore original value if it existed
|
||||
if originalValue != 0 {
|
||||
globalMetrics.configSuccess.Set(originalValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalMetrics_ConfigSuccessTime(t *testing.T) {
|
||||
// Set success time
|
||||
globalMetrics.configSuccessTime.SetToCurrentTime()
|
||||
|
||||
// Get the metric value
|
||||
metricValue := getMetricValue(globalMetrics.configSuccessTime)
|
||||
|
||||
// Should be a reasonable Unix timestamp (not zero)
|
||||
if metricValue == 0 {
|
||||
t.Error("Config success time should not be zero")
|
||||
}
|
||||
|
||||
// Should be recent (within last minute)
|
||||
now := time.Now().Unix()
|
||||
if int64(metricValue) < now-60 || int64(metricValue) > now {
|
||||
t.Errorf("Config success time %f should be recent (now: %d)", metricValue, now)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminMetrics_RequestCount(t *testing.T) {
|
||||
// Initialize admin metrics for testing
|
||||
initAdminMetrics()
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/config",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
}
|
||||
|
||||
// Get initial value
|
||||
initialValue := getCounterValue(adminMetrics.requestCount, labels)
|
||||
|
||||
// Increment counter
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
|
||||
// Verify increment
|
||||
newValue := getCounterValue(adminMetrics.requestCount, labels)
|
||||
if newValue != initialValue+1 {
|
||||
t.Errorf("Expected counter to increment by 1, got %f -> %f", initialValue, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminMetrics_RequestErrors(t *testing.T) {
|
||||
// Initialize admin metrics for testing
|
||||
initAdminMetrics()
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/test",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
// Get initial value
|
||||
initialValue := getCounterValue(adminMetrics.requestErrors, labels)
|
||||
|
||||
// Increment error counter
|
||||
adminMetrics.requestErrors.With(labels).Inc()
|
||||
|
||||
// Verify increment
|
||||
newValue := getCounterValue(adminMetrics.requestErrors, labels)
|
||||
if newValue != initialValue+1 {
|
||||
t.Errorf("Expected error counter to increment by 1, got %f -> %f", initialValue, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetrics_ConcurrentAccess(t *testing.T) {
|
||||
// Initialize admin metrics
|
||||
initAdminMetrics()
|
||||
|
||||
const numGoroutines = 100
|
||||
const incrementsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "concurrent",
|
||||
"path": "/concurrent",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
}
|
||||
|
||||
initialCount := getCounterValue(adminMetrics.requestCount, labels)
|
||||
|
||||
// Concurrent increments
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < incrementsPerGoroutine; j++ {
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify final count
|
||||
finalCount := getCounterValue(adminMetrics.requestCount, labels)
|
||||
expectedIncrement := float64(numGoroutines * incrementsPerGoroutine)
|
||||
|
||||
if finalCount-initialCount != expectedIncrement {
|
||||
t.Errorf("Expected counter to increase by %f, got %f",
|
||||
expectedIncrement, finalCount-initialCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetrics_LabelValidation(t *testing.T) {
|
||||
// Test various label combinations
|
||||
tests := []struct {
|
||||
name string
|
||||
labels prometheus.Labels
|
||||
metric string
|
||||
}{
|
||||
{
|
||||
name: "valid request count labels",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/api/test",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
},
|
||||
metric: "requestCount",
|
||||
},
|
||||
{
|
||||
name: "valid error labels",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/api/error",
|
||||
"method": "POST",
|
||||
},
|
||||
metric: "requestErrors",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "",
|
||||
"method": "GET",
|
||||
"code": "404",
|
||||
},
|
||||
metric: "requestCount",
|
||||
},
|
||||
{
|
||||
name: "special characters in path",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/api/test%20with%20spaces",
|
||||
"method": "PUT",
|
||||
"code": "201",
|
||||
},
|
||||
metric: "requestCount",
|
||||
},
|
||||
}
|
||||
|
||||
initAdminMetrics()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// This should not panic or error
|
||||
switch test.metric {
|
||||
case "requestCount":
|
||||
adminMetrics.requestCount.With(test.labels).Inc()
|
||||
case "requestErrors":
|
||||
adminMetrics.requestErrors.With(test.labels).Inc()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetrics_Initialization_Idempotent(t *testing.T) {
|
||||
// Test that initializing admin metrics multiple times is safe
|
||||
for i := 0; i < 5; i++ {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Iteration %d: initAdminMetrics panicked: %v", i, r)
|
||||
}
|
||||
}()
|
||||
initAdminMetrics()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentHandlerCounter(t *testing.T) {
|
||||
// Create a test counter with the expected labels
|
||||
counter := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "test_counter",
|
||||
Help: "Test counter for instrumentation",
|
||||
},
|
||||
[]string{"code", "method"},
|
||||
)
|
||||
|
||||
// Create instrumented handler
|
||||
testHandler := instrumentHandlerCounter(
|
||||
counter,
|
||||
&mockHTTPHandler{statusCode: 200},
|
||||
)
|
||||
|
||||
// Create test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get initial counter value
|
||||
initialValue := getCounterValue(counter, prometheus.Labels{"code": "200", "method": "GET"})
|
||||
|
||||
// Serve request
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Verify counter was incremented
|
||||
finalValue := getCounterValue(counter, prometheus.Labels{"code": "200", "method": "GET"})
|
||||
if finalValue != initialValue+1 {
|
||||
t.Errorf("Expected counter to increment by 1, got %f -> %f", initialValue, finalValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentHandlerCounter_ErrorStatus(t *testing.T) {
|
||||
counter := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "test_error_counter",
|
||||
Help: "Test counter for error status",
|
||||
},
|
||||
[]string{"code", "method"},
|
||||
)
|
||||
|
||||
// Test different status codes
|
||||
statusCodes := []int{200, 404, 500, 301, 401}
|
||||
|
||||
for _, status := range statusCodes {
|
||||
t.Run(fmt.Sprintf("status_%d", status), func(t *testing.T) {
|
||||
handler := instrumentHandlerCounter(
|
||||
counter,
|
||||
&mockHTTPHandler{statusCode: status},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
statusLabels := prometheus.Labels{"code": fmt.Sprintf("%d", status), "method": "GET"}
|
||||
initialValue := getCounterValue(counter, statusLabels)
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
finalValue := getCounterValue(counter, statusLabels)
|
||||
if finalValue != initialValue+1 {
|
||||
t.Errorf("Status %d: Expected counter increment", status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func getMetricValue(gauge prometheus.Gauge) float64 {
|
||||
metric := &dto.Metric{}
|
||||
gauge.Write(metric)
|
||||
return metric.GetGauge().GetValue()
|
||||
}
|
||||
|
||||
func getCounterValue(counter *prometheus.CounterVec, labels prometheus.Labels) float64 {
|
||||
metric, err := counter.GetMetricWith(labels)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
pb := &dto.Metric{}
|
||||
metric.Write(pb)
|
||||
return pb.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
type mockHTTPHandler struct {
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (m *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(m.statusCode)
|
||||
}
|
||||
|
||||
func TestMetrics_Memory_Usage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory test in short mode")
|
||||
}
|
||||
|
||||
// Initialize metrics
|
||||
initAdminMetrics()
|
||||
|
||||
// Create many different label combinations
|
||||
const numLabels = 1000
|
||||
|
||||
for i := 0; i < numLabels; i++ {
|
||||
labels := prometheus.Labels{
|
||||
"handler": fmt.Sprintf("handler_%d", i%10),
|
||||
"path": fmt.Sprintf("/path_%d", i),
|
||||
"method": []string{"GET", "POST", "PUT", "DELETE"}[i%4],
|
||||
"code": []string{"200", "404", "500"}[i%3],
|
||||
}
|
||||
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
|
||||
// Also increment error counter occasionally
|
||||
if i%10 == 0 {
|
||||
errorLabels := prometheus.Labels{
|
||||
"handler": labels["handler"],
|
||||
"path": labels["path"],
|
||||
"method": labels["method"],
|
||||
}
|
||||
adminMetrics.requestErrors.With(errorLabels).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// Test passes if we don't run out of memory or panic
|
||||
}
|
||||
|
||||
func BenchmarkGlobalMetrics_ConfigSuccess(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
globalMetrics.configSuccess.Set(float64(i % 2))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGlobalMetrics_ConfigSuccessTime(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
globalMetrics.configSuccessTime.SetToCurrentTime()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAdminMetrics_RequestCount_WithLabels(b *testing.B) {
|
||||
initAdminMetrics()
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "benchmark",
|
||||
"path": "/benchmark",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
}
|
||||
}
|
||||
@@ -85,8 +85,11 @@ func (e HandlerError) Unwrap() error { return e.Err }
|
||||
// randString returns a string of n random characters.
|
||||
// It is not even remotely secure OR a proper distribution.
|
||||
// But it's good enough for some things. It excludes certain
|
||||
// confusing characters like I, l, 1, 0, O, etc. If sameCase
|
||||
// is true, then uppercase letters are excluded.
|
||||
// confusing characters like I, l, 1, 0, O. If sameCase
|
||||
// is true, then uppercase letters are excluded as well as
|
||||
// the characters l and o. If sameCase is false, both uppercase
|
||||
// and lowercase letters are used, and the characters I, l, 1, 0, O
|
||||
// are excluded.
|
||||
func randString(n int, sameCase bool) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandlerErrorError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err HandlerError
|
||||
contains []string
|
||||
}{
|
||||
{
|
||||
name: "full error",
|
||||
err: HandlerError{
|
||||
ID: "abc123",
|
||||
StatusCode: 404,
|
||||
Err: fmt.Errorf("not found"),
|
||||
Trace: "pkg.Func (file.go:10)",
|
||||
},
|
||||
contains: []string{"abc123", "404", "not found", "pkg.Func"},
|
||||
},
|
||||
{
|
||||
name: "empty error",
|
||||
err: HandlerError{},
|
||||
contains: []string{},
|
||||
},
|
||||
{
|
||||
name: "error with only status code",
|
||||
err: HandlerError{
|
||||
StatusCode: 500,
|
||||
},
|
||||
contains: []string{"500"},
|
||||
},
|
||||
{
|
||||
name: "error with only message",
|
||||
err: HandlerError{
|
||||
Err: fmt.Errorf("something broke"),
|
||||
},
|
||||
contains: []string{"something broke"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.err.Error()
|
||||
for _, needle := range tt.contains {
|
||||
if !strings.Contains(result, needle) {
|
||||
t.Errorf("Error() = %q, should contain %q", result, needle)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerErrorUnwrap(t *testing.T) {
|
||||
originalErr := fmt.Errorf("original error")
|
||||
he := HandlerError{Err: originalErr}
|
||||
|
||||
unwrapped := he.Unwrap()
|
||||
if unwrapped != originalErr {
|
||||
t.Errorf("Unwrap() = %v, want %v", unwrapped, originalErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
t.Run("creates error with ID and trace", func(t *testing.T) {
|
||||
err := fmt.Errorf("test error")
|
||||
he := Error(500, err)
|
||||
|
||||
if he.StatusCode != 500 {
|
||||
t.Errorf("StatusCode = %d, want 500", he.StatusCode)
|
||||
}
|
||||
if he.ID == "" {
|
||||
t.Error("ID should not be empty")
|
||||
}
|
||||
if len(he.ID) != 9 {
|
||||
t.Errorf("ID length = %d, want 9", len(he.ID))
|
||||
}
|
||||
if he.Trace == "" {
|
||||
t.Error("Trace should not be empty")
|
||||
}
|
||||
if he.Err != err {
|
||||
t.Error("Err should be the original error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unwraps existing HandlerError", func(t *testing.T) {
|
||||
inner := HandlerError{
|
||||
ID: "existing_id",
|
||||
StatusCode: 404,
|
||||
Err: fmt.Errorf("not found"),
|
||||
Trace: "existing trace",
|
||||
}
|
||||
|
||||
he := Error(500, inner)
|
||||
|
||||
// Should keep existing ID
|
||||
if he.ID != "existing_id" {
|
||||
t.Errorf("ID = %q, want 'existing_id'", he.ID)
|
||||
}
|
||||
// Should keep existing StatusCode
|
||||
if he.StatusCode != 404 {
|
||||
t.Errorf("StatusCode = %d, want 404 (existing)", he.StatusCode)
|
||||
}
|
||||
// Should keep existing Trace
|
||||
if he.Trace != "existing trace" {
|
||||
t.Errorf("Trace = %q, want 'existing trace'", he.Trace)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fills missing fields in existing HandlerError", func(t *testing.T) {
|
||||
inner := HandlerError{
|
||||
Err: fmt.Errorf("inner error"),
|
||||
// ID, StatusCode, and Trace are all empty
|
||||
}
|
||||
|
||||
he := Error(503, inner)
|
||||
|
||||
if he.ID == "" {
|
||||
t.Error("should fill missing ID")
|
||||
}
|
||||
if he.StatusCode != 503 {
|
||||
t.Errorf("should fill missing StatusCode with %d, got %d", 503, he.StatusCode)
|
||||
}
|
||||
if he.Trace == "" {
|
||||
t.Error("should fill missing Trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("generates unique IDs", func(t *testing.T) {
|
||||
ids := make(map[string]struct{})
|
||||
for i := 0; i < 100; i++ {
|
||||
he := Error(500, fmt.Errorf("error %d", i))
|
||||
if _, exists := ids[he.ID]; exists {
|
||||
t.Errorf("duplicate ID generated: %s", he.ID)
|
||||
}
|
||||
ids[he.ID] = struct{}{}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorAsHandlerError(t *testing.T) {
|
||||
he := Error(404, fmt.Errorf("not found"))
|
||||
var target HandlerError
|
||||
if !errors.As(he, &target) {
|
||||
t.Error("Error() result should be assertable as HandlerError via errors.As")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerErrorWithWrappedError(t *testing.T) {
|
||||
// Test that errors.As can unwrap a wrapped HandlerError
|
||||
inner := HandlerError{
|
||||
ID: "inner",
|
||||
StatusCode: 404,
|
||||
Err: fmt.Errorf("inner error"),
|
||||
}
|
||||
wrapped := fmt.Errorf("wrapped: %w", inner)
|
||||
|
||||
he := Error(500, wrapped)
|
||||
// Since wrapped contains a HandlerError, it should be unwrapped
|
||||
if he.ID != "inner" {
|
||||
t.Errorf("should unwrap to inner ID 'inner', got %q", he.ID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func TestRandString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
sameCase bool
|
||||
wantLen int
|
||||
checkCase func(string) bool
|
||||
}{
|
||||
{
|
||||
name: "zero length",
|
||||
length: 0,
|
||||
sameCase: false,
|
||||
wantLen: 0,
|
||||
checkCase: func(s string) bool {
|
||||
return s == ""
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "negative length",
|
||||
length: -5,
|
||||
sameCase: false,
|
||||
wantLen: 0,
|
||||
checkCase: func(s string) bool {
|
||||
return s == ""
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single character mixed case",
|
||||
length: 1,
|
||||
sameCase: false,
|
||||
wantLen: 1,
|
||||
checkCase: func(s string) bool {
|
||||
// Should be alphanumeric
|
||||
return len(s) == 1 && (unicode.IsLetter(rune(s[0])) || unicode.IsDigit(rune(s[0])))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single character same case",
|
||||
length: 1,
|
||||
sameCase: true,
|
||||
wantLen: 1,
|
||||
checkCase: func(s string) bool {
|
||||
// Should be lowercase or digit
|
||||
return len(s) == 1 && (unicode.IsLower(rune(s[0])) || unicode.IsDigit(rune(s[0])))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "short string mixed case",
|
||||
length: 5,
|
||||
sameCase: false,
|
||||
wantLen: 5,
|
||||
checkCase: func(s string) bool {
|
||||
// All characters should be alphanumeric
|
||||
for _, c := range s {
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "short string same case",
|
||||
length: 5,
|
||||
sameCase: true,
|
||||
wantLen: 5,
|
||||
checkCase: func(s string) bool {
|
||||
// All characters should be lowercase or digits
|
||||
for _, c := range s {
|
||||
if unicode.IsUpper(c) {
|
||||
return false
|
||||
}
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "medium string mixed case",
|
||||
length: 20,
|
||||
sameCase: false,
|
||||
wantLen: 20,
|
||||
checkCase: func(s string) bool {
|
||||
for _, c := range s {
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "long string same case",
|
||||
length: 100,
|
||||
sameCase: true,
|
||||
wantLen: 100,
|
||||
checkCase: func(s string) bool {
|
||||
for _, c := range s {
|
||||
if unicode.IsUpper(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := randString(tt.length, tt.sameCase)
|
||||
|
||||
// Check length
|
||||
if len(result) != tt.wantLen {
|
||||
t.Errorf("randString(%d, %v) length = %d, want %d",
|
||||
tt.length, tt.sameCase, len(result), tt.wantLen)
|
||||
}
|
||||
|
||||
// Check case requirements
|
||||
if !tt.checkCase(result) {
|
||||
t.Errorf("randString(%d, %v) = %q failed case check",
|
||||
tt.length, tt.sameCase, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRandString_NoConfusingChars ensures that confusing characters
|
||||
// like I, l, 1, 0, O are excluded from the generated strings
|
||||
func TestRandString_NoConfusingChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sameCase bool
|
||||
excluded []rune
|
||||
}{
|
||||
{
|
||||
name: "mixed case excludes I,l,1,0,O",
|
||||
sameCase: false,
|
||||
excluded: []rune{'I', 'l', '1', '0', 'O'},
|
||||
},
|
||||
{
|
||||
name: "same case excludes l,0",
|
||||
sameCase: true,
|
||||
excluded: []rune{'l', 'o'},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate multiple strings to increase confidence
|
||||
for i := 0; i < 100; i++ {
|
||||
result := randString(50, tt.sameCase)
|
||||
|
||||
for _, char := range tt.excluded {
|
||||
if strings.ContainsRune(result, char) {
|
||||
t.Errorf("randString(50, %v) contains excluded character %q in %q",
|
||||
tt.sameCase, char, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRandString_Uniqueness verifies that consecutive calls produce
|
||||
// different strings (with high probability)
|
||||
func TestRandString_Uniqueness(t *testing.T) {
|
||||
const iterations = 100
|
||||
const length = 16
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sameCase bool
|
||||
}{
|
||||
{"mixed case", false},
|
||||
{"same case", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
seen := make(map[string]bool)
|
||||
duplicates := 0
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
result := randString(length, tt.sameCase)
|
||||
if seen[result] {
|
||||
duplicates++
|
||||
}
|
||||
seen[result] = true
|
||||
}
|
||||
|
||||
// With a 16-character string from a large alphabet, duplicates should be extremely rare
|
||||
// Allow at most 1 duplicate in 100 iterations
|
||||
if duplicates > 1 {
|
||||
t.Errorf("randString(%d, %v) produced %d duplicates in %d iterations (expected ≤1)",
|
||||
length, tt.sameCase, duplicates, iterations)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRandString_CharacterDistribution checks that the generated strings
|
||||
// contain a reasonable mix of characters (not just one character)
|
||||
func TestRandString_CharacterDistribution(t *testing.T) {
|
||||
const length = 1000
|
||||
const minUniqueChars = 15 // Should have at least 15 different characters in 1000 chars
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sameCase bool
|
||||
}{
|
||||
{"mixed case", false},
|
||||
{"same case", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := randString(length, tt.sameCase)
|
||||
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, c := range result {
|
||||
uniqueChars[c] = true
|
||||
}
|
||||
|
||||
if len(uniqueChars) < minUniqueChars {
|
||||
t.Errorf("randString(%d, %v) produced only %d unique characters (expected ≥%d)",
|
||||
length, tt.sameCase, len(uniqueChars), minUniqueChars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRandString measures the performance of random string generation
|
||||
func BenchmarkRandString(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
length int
|
||||
sameCase bool
|
||||
}{
|
||||
{"short_mixed", 8, false},
|
||||
{"short_same", 8, true},
|
||||
{"medium_mixed", 32, false},
|
||||
{"medium_same", 32, true},
|
||||
{"long_mixed", 128, false},
|
||||
{"long_same", 128, true},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = randString(bm.length, bm.sameCase)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCIDRExpressionToPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expr string
|
||||
want netip.Prefix
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid CIDR IPv4",
|
||||
expr: "192.168.0.0/16",
|
||||
want: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
{
|
||||
name: "valid CIDR IPv6",
|
||||
expr: "fd00::/8",
|
||||
want: netip.MustParsePrefix("fd00::/8"),
|
||||
},
|
||||
{
|
||||
name: "single IPv4 becomes /32",
|
||||
expr: "192.168.1.1",
|
||||
want: netip.MustParsePrefix("192.168.1.1/32"),
|
||||
},
|
||||
{
|
||||
name: "single IPv6 becomes /128",
|
||||
expr: "::1",
|
||||
want: netip.MustParsePrefix("::1/128"),
|
||||
},
|
||||
{
|
||||
name: "loopback IPv4",
|
||||
expr: "127.0.0.1",
|
||||
want: netip.MustParsePrefix("127.0.0.1/32"),
|
||||
},
|
||||
{
|
||||
name: "full IPv6 address",
|
||||
expr: "2001:db8::1",
|
||||
want: netip.MustParsePrefix("2001:db8::1/128"),
|
||||
},
|
||||
{
|
||||
name: "invalid CIDR",
|
||||
expr: "192.168.0.0/33",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid IP",
|
||||
expr: "not-an-ip",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
expr: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR with invalid IP",
|
||||
expr: "999.999.999.999/24",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR /0 matches everything",
|
||||
expr: "0.0.0.0/0",
|
||||
want: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
},
|
||||
{
|
||||
name: "CIDR /32 single host",
|
||||
expr: "10.0.0.1/32",
|
||||
want: netip.MustParsePrefix("10.0.0.1/32"),
|
||||
},
|
||||
{
|
||||
name: "malformed CIDR with extra slash",
|
||||
expr: "10.0.0.0/8/16",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CIDRExpressionToPrefix(tt.expr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CIDRExpressionToPrefix(%q) error = %v, wantErr %v", tt.expr, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && got != tt.want {
|
||||
t.Errorf("CIDRExpressionToPrefix(%q) = %v, want %v", tt.expr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticIPRangeProvision(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ranges []string
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid CIDR ranges",
|
||||
ranges: []string{"192.168.0.0/16", "10.0.0.0/8"},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "single IPs",
|
||||
ranges: []string{"192.168.1.1", "10.0.0.1"},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "mixed CIDR and single IP",
|
||||
ranges: []string{"192.168.0.0/16", "10.0.0.1"},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "invalid range",
|
||||
ranges: []string{"not-valid"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty ranges",
|
||||
ranges: []string{},
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "nil ranges",
|
||||
ranges: nil,
|
||||
wantLen: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &StaticIPRange{Ranges: tt.ranges}
|
||||
// We can't easily create a caddy.Context here without full module setup,
|
||||
// but Provision only uses the ranges field, so we test the logic directly.
|
||||
// The Provision method calls CIDRExpressionToPrefix which we test separately.
|
||||
var parsedCount int
|
||||
var gotErr bool
|
||||
for _, r := range s.Ranges {
|
||||
_, err := CIDRExpressionToPrefix(r)
|
||||
if err != nil {
|
||||
gotErr = true
|
||||
break
|
||||
}
|
||||
parsedCount++
|
||||
}
|
||||
|
||||
if gotErr != tt.wantErr {
|
||||
t.Errorf("provision error = %v, wantErr %v", gotErr, tt.wantErr)
|
||||
}
|
||||
if !tt.wantErr && parsedCount != tt.wantLen {
|
||||
t.Errorf("parsed %d ranges, want %d", parsedCount, tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticIPRangeGetIPRanges(t *testing.T) {
|
||||
s := &StaticIPRange{
|
||||
ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
}
|
||||
|
||||
result := s.GetIPRanges(nil) // request is unused
|
||||
if len(result) != 2 {
|
||||
t.Errorf("GetIPRanges() returned %d prefixes, want 2", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticIPRangeCaddyModule(t *testing.T) {
|
||||
s := StaticIPRange{}
|
||||
info := s.CaddyModule()
|
||||
if info.ID != "http.ip_sources.static" {
|
||||
t.Errorf("CaddyModule().ID = %v, want 'http.ip_sources.static'", info.ID)
|
||||
}
|
||||
mod := info.New()
|
||||
if mod == nil {
|
||||
t.Error("New() should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateRangesCIDRWrapper(t *testing.T) {
|
||||
ranges := PrivateRangesCIDR()
|
||||
if len(ranges) == 0 {
|
||||
t.Error("PrivateRangesCIDR() should return non-empty list")
|
||||
}
|
||||
|
||||
// Verify all ranges are valid CIDR or IP expressions
|
||||
for _, r := range ranges {
|
||||
_, err := CIDRExpressionToPrefix(r)
|
||||
if err != nil {
|
||||
t.Errorf("PrivateRangesCIDR() returned invalid range %q: %v", r, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,316 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestLoggableHTTPRequestMarshal(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "https://example.com/path?q=1", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
req.Header.Set("Accept", "text/html")
|
||||
|
||||
ctx := context.WithValue(req.Context(), VarsCtxKey, map[string]any{
|
||||
ClientIPVarKey: "192.168.1.1",
|
||||
})
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
lr := LoggableHTTPRequest{Request: req}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := lr.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
if enc.Fields["remote_ip"] != "192.168.1.1" {
|
||||
t.Errorf("remote_ip = %v, want '192.168.1.1'", enc.Fields["remote_ip"])
|
||||
}
|
||||
if enc.Fields["remote_port"] != "12345" {
|
||||
t.Errorf("remote_port = %v, want '12345'", enc.Fields["remote_port"])
|
||||
}
|
||||
if enc.Fields["client_ip"] != "192.168.1.1" {
|
||||
t.Errorf("client_ip = %v, want '192.168.1.1'", enc.Fields["client_ip"])
|
||||
}
|
||||
if enc.Fields["method"] != "GET" {
|
||||
t.Errorf("method = %v, want 'GET'", enc.Fields["method"])
|
||||
}
|
||||
if enc.Fields["host"] != "example.com" {
|
||||
t.Errorf("host = %v, want 'example.com'", enc.Fields["host"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggableHTTPRequestNoPort(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com/", nil)
|
||||
req.RemoteAddr = "192.168.1.1" // no port
|
||||
|
||||
ctx := context.WithValue(req.Context(), VarsCtxKey, map[string]any{})
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
lr := LoggableHTTPRequest{Request: req}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := lr.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
if enc.Fields["remote_ip"] != "192.168.1.1" {
|
||||
t.Errorf("remote_ip = %v, want '192.168.1.1'", enc.Fields["remote_ip"])
|
||||
}
|
||||
if enc.Fields["remote_port"] != "" {
|
||||
t.Errorf("remote_port = %v, want empty string", enc.Fields["remote_port"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggableHTTPHeaderRedaction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
shouldLogCredentials bool
|
||||
expectRedacted []string
|
||||
}{
|
||||
{
|
||||
name: "redacts sensitive headers",
|
||||
header: http.Header{
|
||||
"Cookie": {"session=abc123"},
|
||||
"Set-Cookie": {"session=xyz"},
|
||||
"Authorization": {"Bearer token123"},
|
||||
"Proxy-Authorization": {"Basic credentials"},
|
||||
"User-Agent": {"test-agent"},
|
||||
},
|
||||
shouldLogCredentials: false,
|
||||
expectRedacted: []string{"Cookie", "Set-Cookie", "Authorization", "Proxy-Authorization"},
|
||||
},
|
||||
{
|
||||
name: "logs credentials when enabled",
|
||||
header: http.Header{
|
||||
"Cookie": {"session=abc123"},
|
||||
"Authorization": {"Bearer token123"},
|
||||
},
|
||||
shouldLogCredentials: true,
|
||||
expectRedacted: nil, // nothing should be redacted
|
||||
},
|
||||
{
|
||||
name: "nil header",
|
||||
header: nil,
|
||||
shouldLogCredentials: false,
|
||||
expectRedacted: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := LoggableHTTPHeader{Header: tt.header, ShouldLogCredentials: tt.shouldLogCredentials}
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := h.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
if tt.header == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range tt.expectRedacted {
|
||||
// The encoded value should be an array with ["REDACTED"]
|
||||
if arr, ok := enc.Fields[key]; ok {
|
||||
arrEnc, ok := arr.(zapcore.ArrayMarshaler)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Marshal the array to check its contents
|
||||
testEnc := &testArrayEncoder{}
|
||||
_ = arrEnc.MarshalLogArray(testEnc)
|
||||
if len(testEnc.items) != 1 || testEnc.items[0] != "REDACTED" {
|
||||
t.Errorf("header %q should be REDACTED, got %v", key, testEnc.items)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tt.shouldLogCredentials && tt.header != nil {
|
||||
for key, vals := range tt.header {
|
||||
if arr, ok := enc.Fields[key]; ok {
|
||||
arrEnc, ok := arr.(zapcore.ArrayMarshaler)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
testEnc := &testArrayEncoder{}
|
||||
_ = arrEnc.MarshalLogArray(testEnc)
|
||||
if len(testEnc.items) > 0 && testEnc.items[0] == "REDACTED" {
|
||||
t.Errorf("header %q should NOT be redacted when credentials logging is enabled, original: %v", key, vals)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testArrayEncoder is a simple array encoder for testing
|
||||
type testArrayEncoder struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
func (e *testArrayEncoder) AppendString(s string) { e.items = append(e.items, s) }
|
||||
func (e *testArrayEncoder) AppendBool(bool) {}
|
||||
func (e *testArrayEncoder) AppendByteString([]byte) {}
|
||||
func (e *testArrayEncoder) AppendComplex128(complex128) {}
|
||||
func (e *testArrayEncoder) AppendComplex64(complex64) {}
|
||||
func (e *testArrayEncoder) AppendFloat64(float64) {}
|
||||
func (e *testArrayEncoder) AppendFloat32(float32) {}
|
||||
func (e *testArrayEncoder) AppendInt(int) {}
|
||||
func (e *testArrayEncoder) AppendInt64(int64) {}
|
||||
func (e *testArrayEncoder) AppendInt32(int32) {}
|
||||
func (e *testArrayEncoder) AppendInt16(int16) {}
|
||||
func (e *testArrayEncoder) AppendInt8(int8) {}
|
||||
func (e *testArrayEncoder) AppendUint(uint) {}
|
||||
func (e *testArrayEncoder) AppendUint64(uint64) {}
|
||||
func (e *testArrayEncoder) AppendUint32(uint32) {}
|
||||
func (e *testArrayEncoder) AppendUint16(uint16) {}
|
||||
func (e *testArrayEncoder) AppendUint8(uint8) {}
|
||||
func (e *testArrayEncoder) AppendUintptr(uintptr) {}
|
||||
func (e *testArrayEncoder) AppendDuration(time.Duration) {}
|
||||
func (e *testArrayEncoder) AppendTime(time.Time) {}
|
||||
func (e *testArrayEncoder) AppendArray(zapcore.ArrayMarshaler) error { return nil }
|
||||
func (e *testArrayEncoder) AppendObject(zapcore.ObjectMarshaler) error { return nil }
|
||||
func (e *testArrayEncoder) AppendReflected(any) error { return nil }
|
||||
|
||||
func TestLoggableStringArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input LoggableStringArray
|
||||
}{
|
||||
{
|
||||
name: "nil array",
|
||||
input: nil,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
input: LoggableStringArray{},
|
||||
},
|
||||
{
|
||||
name: "single element",
|
||||
input: LoggableStringArray{"hello"},
|
||||
},
|
||||
{
|
||||
name: "multiple elements",
|
||||
input: LoggableStringArray{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
enc := &testArrayEncoder{}
|
||||
err := tt.input.MarshalLogArray(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogArray() error = %v", err)
|
||||
}
|
||||
if tt.input != nil && len(enc.items) != len(tt.input) {
|
||||
t.Errorf("expected %d items, got %d", len(tt.input), len(enc.items))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggableTLSConnState(t *testing.T) {
|
||||
t.Run("basic TLS state", func(t *testing.T) {
|
||||
state := LoggableTLSConnState(tls.ConnectionState{
|
||||
Version: tls.VersionTLS13,
|
||||
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||
NegotiatedProtocol: "h2",
|
||||
ServerName: "example.com",
|
||||
})
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := state.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
if enc.Fields["proto"] != "h2" {
|
||||
t.Errorf("proto = %v, want 'h2'", enc.Fields["proto"])
|
||||
}
|
||||
if enc.Fields["server_name"] != "example.com" {
|
||||
t.Errorf("server_name = %v, want 'example.com'", enc.Fields["server_name"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLS state with peer certificates", func(t *testing.T) {
|
||||
// Skipping detailed cert subject test since x509.Certificate creation
|
||||
// for testing requires complex setup; covered by the no-peer-certs test
|
||||
state := LoggableTLSConnState(tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
CipherSuite: tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
})
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := state.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
if enc.Fields["version"] != uint16(tls.VersionTLS12) {
|
||||
t.Errorf("version = %v, want TLS 1.2", enc.Fields["version"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLS state without peer certificates", func(t *testing.T) {
|
||||
state := LoggableTLSConnState(tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
})
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := state.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
// Should not contain client cert fields when no peer certs
|
||||
if _, ok := enc.Fields["client_common_name"]; ok {
|
||||
t.Error("should not have client_common_name without peer certificates")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoggableHTTPHeaderCaseInsensitivity(t *testing.T) {
|
||||
// HTTP headers should be case-insensitive for redaction
|
||||
h := LoggableHTTPHeader{
|
||||
Header: http.Header{
|
||||
"AUTHORIZATION": {"Bearer secret"},
|
||||
"cookie": {"session=abc"},
|
||||
"Proxy-Authorization": {"Basic creds"},
|
||||
},
|
||||
ShouldLogCredentials: false,
|
||||
}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
err := h.MarshalLogObject(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("MarshalLogObject() error = %v", err)
|
||||
}
|
||||
|
||||
// All sensitive headers should be redacted regardless of casing
|
||||
// Note: http.Header canonicalizes keys, so "cookie" becomes "Cookie"
|
||||
for key := range enc.Fields {
|
||||
lk := strings.ToLower(key)
|
||||
if lk == "cookie" || lk == "authorization" || lk == "proxy-authorization" {
|
||||
arr, ok := enc.Fields[key].(zapcore.ArrayMarshaler)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
testEnc := &testArrayEncoder{}
|
||||
_ = arr.MarshalLogArray(testEnc)
|
||||
if len(testEnc.items) != 1 || testEnc.items[0] != "REDACTED" {
|
||||
t.Errorf("header %q should be REDACTED, got %v", key, testEnc.items)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -21,8 +21,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
// ResponseWriterWrapper wraps an underlying ResponseWriter and
|
||||
@@ -72,8 +70,6 @@ type responseRecorder struct {
|
||||
size int
|
||||
wroteHeader bool
|
||||
stream bool
|
||||
hijacked bool
|
||||
detached bool
|
||||
|
||||
readSize *int
|
||||
}
|
||||
@@ -148,8 +144,7 @@ func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer
|
||||
|
||||
// WriteHeader writes the headers with statusCode to the wrapped
|
||||
// ResponseWriter unless the response is to be buffered instead.
|
||||
// 1xx responses are never buffered, except 101 which is treated
|
||||
// as a final upgrade response.
|
||||
// 1xx responses are never buffered.
|
||||
func (rr *responseRecorder) WriteHeader(statusCode int) {
|
||||
if rr.wroteHeader {
|
||||
return
|
||||
@@ -166,12 +161,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
|
||||
rr.stream = !rr.shouldBuffer(rr.statusCode, rr.ResponseWriterWrapper.Header())
|
||||
}
|
||||
|
||||
// 1xx responses except 101 aren't final; just informational
|
||||
if statusCode < 100 || statusCode > 199 || statusCode == http.StatusSwitchingProtocols {
|
||||
// 1xx responses aren't final; just informational
|
||||
if statusCode < 100 || statusCode > 199 {
|
||||
rr.wroteHeader = true
|
||||
}
|
||||
|
||||
// if 1xx or not buffered, immediately write header
|
||||
// if informational or not buffered, immediately write header
|
||||
if rr.stream || (100 <= statusCode && statusCode <= 199) {
|
||||
rr.ResponseWriterWrapper.WriteHeader(statusCode)
|
||||
}
|
||||
@@ -227,18 +222,7 @@ func (rr *responseRecorder) Buffered() bool {
|
||||
return !rr.stream
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) DetachAfterHijack(detached bool) bool {
|
||||
if rr.hijacked {
|
||||
return false
|
||||
}
|
||||
rr.detached = detached
|
||||
return true
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) WriteResponse() error {
|
||||
if rr.hijacked {
|
||||
return nil
|
||||
}
|
||||
if rr.statusCode == 0 {
|
||||
// could happen if no handlers actually wrote anything,
|
||||
// and this prevents a panic; status must be > 0
|
||||
@@ -269,25 +253,11 @@ func (rr *responseRecorder) setReadSize(size *int) {
|
||||
}
|
||||
|
||||
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if !rr.wroteHeader {
|
||||
// hijacking without writing status code first works as long as
|
||||
// subsequent writes follows http1.1 wire format, but it will
|
||||
// show up with a status code of 0 in the access log and bytes
|
||||
// written will include response headers. Response headers won't
|
||||
// be present in the log if not set on the response writer.
|
||||
caddy.Log().Warn("hijacking without writing status code first")
|
||||
}
|
||||
//nolint:bodyclose
|
||||
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rr.hijacked = true
|
||||
rr.stream = true
|
||||
rr.wroteHeader = true
|
||||
if rr.detached {
|
||||
return conn, brw, nil
|
||||
}
|
||||
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
|
||||
conn = &hijackedConn{conn, rr}
|
||||
brw.Writer.Reset(conn)
|
||||
@@ -341,29 +311,6 @@ func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// DetachResponseWriterAfterHijack detaches w or one of its wrapped
|
||||
// response writers when it's hijacked. Returns true if not already
|
||||
// hijacked. When detached, bytes read or written stats will not be
|
||||
// recorded for the hijacked connection, and it's safe to use the
|
||||
// connection after http middleware returns.
|
||||
func DetachResponseWriterAfterHijack(w http.ResponseWriter, detached bool) bool {
|
||||
for w != nil {
|
||||
if detacher, ok := w.(interface{ DetachAfterHijack(bool) bool }); ok {
|
||||
return detacher.DetachAfterHijack(detached)
|
||||
}
|
||||
unwrapper, ok := w.(interface{ Unwrap() http.ResponseWriter })
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
next := unwrapper.Unwrap()
|
||||
if next == w {
|
||||
return false
|
||||
}
|
||||
w = next
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ResponseRecorder is a http.ResponseWriter that records
|
||||
// responses instead of writing them to the client. See
|
||||
// docs for NewResponseRecorder for proper usage.
|
||||
@@ -372,7 +319,6 @@ type ResponseRecorder interface {
|
||||
Status() int
|
||||
Buffer() *bytes.Buffer
|
||||
Buffered() bool
|
||||
DetachAfterHijack(bool) bool
|
||||
Size() int
|
||||
WriteResponse() error
|
||||
}
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type responseWriterSpy interface {
|
||||
@@ -47,50 +44,6 @@ func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
|
||||
|
||||
func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }
|
||||
|
||||
type hijackRespWriter struct {
|
||||
baseRespWriter
|
||||
header http.Header
|
||||
status int
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func newHijackRespWriter() *hijackRespWriter {
|
||||
return &hijackRespWriter{
|
||||
header: make(http.Header),
|
||||
conn: stubConn{},
|
||||
}
|
||||
}
|
||||
|
||||
func (hrw *hijackRespWriter) Header() http.Header {
|
||||
return hrw.header
|
||||
}
|
||||
|
||||
func (hrw *hijackRespWriter) WriteHeader(statusCode int) {
|
||||
hrw.status = statusCode
|
||||
}
|
||||
|
||||
func (hrw *hijackRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
br := bufio.NewReader(hrw.conn)
|
||||
bw := bufio.NewWriter(hrw.conn)
|
||||
return hrw.conn, bufio.NewReadWriter(br, bw), nil
|
||||
}
|
||||
|
||||
type stubConn struct{}
|
||||
|
||||
func (stubConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (stubConn) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (stubConn) Close() error { return nil }
|
||||
func (stubConn) LocalAddr() net.Addr { return stubAddr("local") }
|
||||
func (stubConn) RemoteAddr() net.Addr { return stubAddr("remote") }
|
||||
func (stubConn) SetDeadline(time.Time) error { return nil }
|
||||
func (stubConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (stubConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type stubAddr string
|
||||
|
||||
func (a stubAddr) Network() string { return "tcp" }
|
||||
func (a stubAddr) String() string { return string(a) }
|
||||
|
||||
func TestResponseWriterWrapperReadFrom(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
responseWriter responseWriterSpy
|
||||
@@ -216,49 +169,3 @@ func TestResponseRecorderReadFrom(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRecorderSwitchingProtocolsIsHijackAware(t *testing.T) {
|
||||
w := newHijackRespWriter()
|
||||
var buf bytes.Buffer
|
||||
|
||||
rr := NewResponseRecorder(w, &buf, func(status int, header http.Header) bool {
|
||||
return true
|
||||
})
|
||||
rr.WriteHeader(http.StatusSwitchingProtocols)
|
||||
|
||||
if rr.Status() != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("status = %d, want %d", rr.Status(), http.StatusSwitchingProtocols)
|
||||
}
|
||||
if w.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("underlying status = %d, want %d", w.status, http.StatusSwitchingProtocols)
|
||||
}
|
||||
|
||||
hj, ok := rr.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Fatal("response recorder does not implement http.Hijacker")
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("Hijack() error = %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if rr.Buffered() {
|
||||
t.Fatal("hijacked response should not remain buffered")
|
||||
}
|
||||
if rr.DetachAfterHijack(true) {
|
||||
t.Fatal("response recorder should report hijacked state by returning false")
|
||||
}
|
||||
if DetachResponseWriterAfterHijack(rr, true) {
|
||||
t.Fatal("DetachResponseWriterAfterHijack() should report false after hijack")
|
||||
}
|
||||
if err := rr.WriteResponse(); err != nil {
|
||||
t.Fatalf("WriteResponse() after hijack returned error: %v", err)
|
||||
}
|
||||
if rr.Size() != 0 {
|
||||
t.Fatalf("size = %d, want 0 after hijack handshake", rr.Size())
|
||||
}
|
||||
if got := w.Written(); got != "" {
|
||||
t.Fatalf("unexpected buffered body write after hijack: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,12 +99,6 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
|
||||
// stream_buffer_size <size>
|
||||
// stream_timeout <duration>
|
||||
// stream_close_delay <duration>
|
||||
// stream_detached
|
||||
// stream_logs {
|
||||
// level <debug|info|warn|error>
|
||||
// logger_name <name|access>
|
||||
// skip_handshake
|
||||
// }
|
||||
// verbose_logs
|
||||
//
|
||||
// # request manipulation
|
||||
@@ -709,49 +703,6 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
h.StreamCloseDelay = caddy.Duration(dur)
|
||||
}
|
||||
|
||||
case "stream_detached":
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.StreamDetached = true
|
||||
|
||||
case "stream_logs":
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
if h.StreamLogs == nil {
|
||||
h.StreamLogs = new(StreamLogs)
|
||||
}
|
||||
|
||||
nesting := d.Nesting()
|
||||
for d.NextBlock(nesting) {
|
||||
switch d.Val() {
|
||||
case "level":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.StreamLogs.Level = d.Val()
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
case "logger_name":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.StreamLogs.LoggerName = d.Val()
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
case "skip_handshake":
|
||||
if d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.StreamLogs.SkipHandshake = true
|
||||
default:
|
||||
return d.Errf("unrecognized stream_logs option: %s", d.Val())
|
||||
}
|
||||
}
|
||||
|
||||
case "trusted_proxies":
|
||||
for d.NextArg() {
|
||||
if d.Val() == "private_ranges" {
|
||||
|
||||
@@ -80,7 +80,7 @@ func (h CopyResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
|
||||
hrc.isFinalized = true
|
||||
|
||||
// write the response
|
||||
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger, hrc.upstreamAddr)
|
||||
return hrc.handler.finalizeResponse(rw, req, hrc.response, repl, hrc.start, hrc.logger)
|
||||
}
|
||||
|
||||
// CopyResponseHeadersHandler is a special HTTP handler which may
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
type extendedConnectCapture struct {
|
||||
method string
|
||||
headers http.Header
|
||||
body []byte
|
||||
extendedBodyPresent bool
|
||||
extendedConnectBody []byte
|
||||
}
|
||||
|
||||
type extendedConnectCaptureTransport struct {
|
||||
mu sync.Mutex
|
||||
capture extendedConnectCapture
|
||||
}
|
||||
|
||||
func (tr *extendedConnectCaptureTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := extendedConnectCapture{
|
||||
method: req.Method,
|
||||
headers: req.Header.Clone(),
|
||||
body: body,
|
||||
}
|
||||
if rc, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
c.extendedBodyPresent = true
|
||||
c.extendedConnectBody, err = io.ReadAll(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = rc.Close()
|
||||
}
|
||||
|
||||
tr.mu.Lock()
|
||||
tr.capture = c
|
||||
tr.mu.Unlock()
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
Request: req,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tr *extendedConnectCaptureTransport) Snapshot() extendedConnectCapture {
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
return tr.capture
|
||||
}
|
||||
|
||||
func TestServeHTTPRewritesExtendedConnectWebsocketRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protoMajor int
|
||||
proto string
|
||||
headers map[string]string
|
||||
}{
|
||||
{
|
||||
name: "h2 extended connect",
|
||||
protoMajor: 2,
|
||||
proto: "HTTP/2.0",
|
||||
headers: map[string]string{
|
||||
":protocol": "websocket",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "h3 extended connect",
|
||||
protoMajor: 3,
|
||||
proto: "websocket",
|
||||
headers: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const payload = "extended-connect-body"
|
||||
|
||||
transport := new(extendedConnectCaptureTransport)
|
||||
h := &Handler{
|
||||
logger: zap.NewNop(),
|
||||
Transport: transport,
|
||||
Upstreams: UpstreamPool{
|
||||
&Upstream{Host: new(Host), Dial: "127.0.0.1:8443"},
|
||||
},
|
||||
LoadBalancing: &LoadBalancing{
|
||||
SelectionPolicy: &RoundRobinSelection{},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodConnect, "http://example.test/upgrade", strings.NewReader(payload))
|
||||
req.ProtoMajor = tc.protoMajor
|
||||
req.Proto = tc.proto
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
req = prepareTestRequest(req)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
err := h.ServeHTTP(rr, req, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("ServeHTTP() error = %v", err)
|
||||
}
|
||||
|
||||
captured := transport.Snapshot()
|
||||
if captured.method != http.MethodGet {
|
||||
t.Fatalf("upstream method = %s, want %s", captured.method, http.MethodGet)
|
||||
}
|
||||
if got := captured.headers.Get("Upgrade"); !strings.EqualFold(got, "websocket") {
|
||||
t.Fatalf("Upgrade header = %q, want websocket", got)
|
||||
}
|
||||
if got := captured.headers.Get("Connection"); !strings.EqualFold(got, "Upgrade") {
|
||||
t.Fatalf("Connection header = %q, want Upgrade", got)
|
||||
}
|
||||
if got := captured.headers.Get(":protocol"); got != "" {
|
||||
t.Fatalf(":protocol header should be removed, got %q", got)
|
||||
}
|
||||
if len(captured.body) != 0 {
|
||||
t.Fatalf("upstream request body length = %d, want 0", len(captured.body))
|
||||
}
|
||||
if !captured.extendedBodyPresent {
|
||||
t.Fatal("extended_connect_websocket_body variable missing from request context")
|
||||
}
|
||||
if string(captured.extendedConnectBody) != payload {
|
||||
t.Fatalf("extended_connect_websocket_body = %q, want %q", string(captured.extendedConnectBody), payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,6 @@ import (
|
||||
var reverseProxyMetrics = struct {
|
||||
once sync.Once
|
||||
upstreamsHealthy *prometheus.GaugeVec
|
||||
streamsActive *prometheus.GaugeVec
|
||||
streamsTotal *prometheus.CounterVec
|
||||
streamDuration *prometheus.HistogramVec
|
||||
streamBytes *prometheus.CounterVec
|
||||
logger *zap.Logger
|
||||
}{}
|
||||
|
||||
@@ -27,8 +23,6 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
|
||||
const ns, sub = "caddy", "reverse_proxy"
|
||||
|
||||
upstreamsLabels := []string{"upstream"}
|
||||
streamResultLabels := []string{"upstream", "result"}
|
||||
streamBytesLabels := []string{"upstream", "direction"}
|
||||
reverseProxyMetrics.once.Do(func() {
|
||||
reverseProxyMetrics.upstreamsHealthy = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: ns,
|
||||
@@ -36,31 +30,6 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
|
||||
Name: "upstreams_healthy",
|
||||
Help: "Health status of reverse proxy upstreams.",
|
||||
}, upstreamsLabels)
|
||||
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "streams_active",
|
||||
Help: "Number of currently active upgraded reverse proxy streams.",
|
||||
}, upstreamsLabels)
|
||||
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "streams_total",
|
||||
Help: "Total number of upgraded reverse proxy streams by close result.",
|
||||
}, streamResultLabels)
|
||||
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "stream_duration_seconds",
|
||||
Help: "Duration of upgraded reverse proxy streams by close result.",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
}, streamResultLabels)
|
||||
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: ns,
|
||||
Subsystem: sub,
|
||||
Name: "stream_bytes_total",
|
||||
Help: "Total bytes proxied across upgraded reverse proxy streams.",
|
||||
}, streamBytesLabels)
|
||||
})
|
||||
|
||||
// duplicate registration could happen if multiple sites with reverse proxy are configured; so ignore the error because
|
||||
@@ -73,58 +42,10 @@ func initReverseProxyMetrics(handler *Handler, registry *prometheus.Registry) {
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamsActive); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamsActive,
|
||||
NewCollector: reverseProxyMetrics.streamsActive,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamsTotal); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamsTotal,
|
||||
NewCollector: reverseProxyMetrics.streamsTotal,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamDuration); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamDuration,
|
||||
NewCollector: reverseProxyMetrics.streamDuration,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
if err := registry.Register(reverseProxyMetrics.streamBytes); err != nil &&
|
||||
!errors.Is(err, prometheus.AlreadyRegisteredError{
|
||||
ExistingCollector: reverseProxyMetrics.streamBytes,
|
||||
NewCollector: reverseProxyMetrics.streamBytes,
|
||||
}) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
reverseProxyMetrics.logger = handler.logger.Named("reverse_proxy.metrics")
|
||||
}
|
||||
|
||||
func trackActiveStream(upstream string) func(result string, duration time.Duration, toBackend, fromBackend int64) {
|
||||
labels := prometheus.Labels{"upstream": upstream}
|
||||
reverseProxyMetrics.streamsActive.With(labels).Inc()
|
||||
|
||||
var once sync.Once
|
||||
return func(result string, duration time.Duration, toBackend, fromBackend int64) {
|
||||
once.Do(func() {
|
||||
reverseProxyMetrics.streamsActive.With(labels).Dec()
|
||||
reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, result).Inc()
|
||||
reverseProxyMetrics.streamDuration.WithLabelValues(upstream, result).Observe(duration.Seconds())
|
||||
if toBackend > 0 {
|
||||
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream").Add(float64(toBackend))
|
||||
}
|
||||
if fromBackend > 0 {
|
||||
reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream").Add(float64(fromBackend))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type metricsUpstreamsHealthyUpdater struct {
|
||||
handler *Handler
|
||||
}
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
)
|
||||
|
||||
func TestTrackActiveStreamRecordsLifecycleAndBytes(t *testing.T) {
|
||||
const upstream = "127.0.0.1:7443"
|
||||
|
||||
// Use fresh metric vectors for deterministic assertions in this unit test.
|
||||
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
|
||||
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
|
||||
|
||||
finish := trackActiveStream(upstream)
|
||||
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 1 {
|
||||
t.Fatalf("active streams = %v, want 1", got)
|
||||
}
|
||||
|
||||
finish("closed", 150*time.Millisecond, 1234, 4321)
|
||||
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsActive.WithLabelValues(upstream)); got != 0 {
|
||||
t.Fatalf("active streams = %v, want 0", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "closed")); got != 1 {
|
||||
t.Fatalf("streams_total closed = %v, want 1", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 1234 {
|
||||
t.Fatalf("bytes to_upstream = %v, want 1234", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 4321 {
|
||||
t.Fatalf("bytes from_upstream = %v, want 4321", got)
|
||||
}
|
||||
|
||||
// A second finish call should be ignored by the once guard.
|
||||
finish("error", 1*time.Second, 111, 222)
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "error")); got != 0 {
|
||||
t.Fatalf("streams_total error = %v, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrackActiveStreamDoesNotCountZeroBytes(t *testing.T) {
|
||||
const upstream = "127.0.0.1:9000"
|
||||
|
||||
reverseProxyMetrics.streamsActive = prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"upstream"})
|
||||
reverseProxyMetrics.streamsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{}, []string{"upstream", "result"})
|
||||
reverseProxyMetrics.streamBytes = prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"upstream", "direction"})
|
||||
|
||||
trackActiveStream(upstream)("timeout", 250*time.Millisecond, 0, 0)
|
||||
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "to_upstream")); got != 0 {
|
||||
t.Fatalf("bytes to_upstream = %v, want 0", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamBytes.WithLabelValues(upstream, "from_upstream")); got != 0 {
|
||||
t.Fatalf("bytes from_upstream = %v, want 0", got)
|
||||
}
|
||||
if got := testutil.ToFloat64(reverseProxyMetrics.streamsTotal.WithLabelValues(upstream, "timeout")); got != 1 {
|
||||
t.Fatalf("streams_total timeout = %v, want 1", got)
|
||||
}
|
||||
}
|
||||
@@ -186,22 +186,6 @@ type Handler struct {
|
||||
// by the previous config closing. Default: no delay.
|
||||
StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
|
||||
|
||||
// If true, upgraded connections such as WebSockets are detached from
|
||||
// the handler and retained across config reloads when their upstream
|
||||
// still exists in the new config. Connections using upstreams that are
|
||||
// removed are closed during cleanup. By default this is false, preserving
|
||||
// legacy behavior where upgraded connections are closed on reload
|
||||
// (optionally delayed by stream_close_delay).
|
||||
// Only http1.1 websocket connections are affected, websockets for h2/h3
|
||||
// are not affected. If true, bytes transferred for http1.1 in the access
|
||||
// logs will be zero but those stats can be found in the stream logs for
|
||||
// http1/2/3 regardless if this is enabled.
|
||||
StreamDetached bool `json:"stream_detached,omitempty"`
|
||||
|
||||
// Controls logging behavior for upgraded stream lifecycle events.
|
||||
// If omitted, defaults are used (level=DEBUG, logger_name="http.handlers.reverse_proxy.stream").
|
||||
StreamLogs *StreamLogs `json:"stream_logs,omitempty"`
|
||||
|
||||
// If configured, rewrites the copy of the upstream request.
|
||||
// Allows changing the request method and URI (path and query).
|
||||
// Since the rewrite is applied to the copy, it does not persist
|
||||
@@ -256,16 +240,14 @@ type Handler struct {
|
||||
// Holds the handle_response Caddyfile tokens while adapting
|
||||
handleResponseSegments []*caddyfile.Dispenser
|
||||
|
||||
// Tracks hijacked/upgraded connections (WebSocket etc.) so they can be
|
||||
// closed when their upstream is removed from the config.
|
||||
tunnelTracker *tunnelTracker
|
||||
// Stores upgraded requests (hijacked connections) for proper cleanup
|
||||
connections map[io.ReadWriteCloser]openConnection
|
||||
connectionsCloseTimer *time.Timer
|
||||
connectionsMu *sync.Mutex
|
||||
|
||||
ctx caddy.Context
|
||||
logger *zap.Logger
|
||||
events *caddyevents.App
|
||||
|
||||
streamLogLevel zapcore.Level
|
||||
streamLogLoggerName string
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
@@ -285,25 +267,8 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
h.events = eventAppIface.(*caddyevents.App)
|
||||
h.ctx = ctx
|
||||
h.logger = ctx.Logger()
|
||||
h.tunnelTracker = newTunnelTracker(h.logger, time.Duration(h.StreamCloseDelay))
|
||||
h.streamLogLevel = defaultStreamLogLevel
|
||||
h.streamLogLoggerName = defaultStreamLoggerName
|
||||
if h.StreamLogs != nil {
|
||||
if h.StreamLogs.Level != "" {
|
||||
lvl, err := zapcore.ParseLevel(strings.ToLower(strings.TrimSpace(h.StreamLogs.Level)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid stream_logs.level %q: %w", h.StreamLogs.Level, err)
|
||||
}
|
||||
h.streamLogLevel = lvl
|
||||
}
|
||||
if name := strings.TrimSpace(h.StreamLogs.LoggerName); name != "" {
|
||||
h.streamLogLoggerName = name
|
||||
}
|
||||
}
|
||||
|
||||
if h.StreamDetached {
|
||||
registerDetachedTunnelTrackers(h.tunnelTracker)
|
||||
}
|
||||
h.connections = make(map[io.ReadWriteCloser]openConnection)
|
||||
h.connectionsMu = new(sync.Mutex)
|
||||
|
||||
// warn about unsafe buffering config
|
||||
if h.RequestBuffers == -1 || h.ResponseBuffers == -1 {
|
||||
@@ -472,85 +437,15 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h Handler) streamLogsSkipHandshake() bool {
|
||||
return h.StreamLogs != nil && h.StreamLogs.SkipHandshake
|
||||
}
|
||||
|
||||
func (h Handler) streamLoggerForRequest(req *http.Request) *zap.Logger {
|
||||
name := strings.TrimSpace(h.streamLogLoggerName)
|
||||
if name == "" {
|
||||
name = defaultStreamLoggerName
|
||||
}
|
||||
|
||||
if name == streamLoggerNameUseAccess {
|
||||
logger := caddy.Log().Named(defaultAccessLoggerBase)
|
||||
names := caddyhttp.GetVar(req.Context(), caddyhttp.AccessLoggerNameVarKey)
|
||||
namesSlice, ok := names.([]any)
|
||||
if !ok {
|
||||
return logger
|
||||
}
|
||||
for _, v := range namesSlice {
|
||||
name, ok := v.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name == "" {
|
||||
return logger
|
||||
}
|
||||
return logger.Named(name)
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
return caddy.Log().Named(name)
|
||||
}
|
||||
|
||||
var (
|
||||
detachedTunnelTrackers = make(map[*tunnelTracker]struct{})
|
||||
detachedTunnelTrackersMu sync.Mutex
|
||||
)
|
||||
|
||||
func registerDetachedTunnelTrackers(ts *tunnelTracker) {
|
||||
detachedTunnelTrackersMu.Lock()
|
||||
defer detachedTunnelTrackersMu.Unlock()
|
||||
detachedTunnelTrackers[ts] = struct{}{}
|
||||
}
|
||||
|
||||
func notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream string, self *tunnelTracker) error {
|
||||
detachedTunnelTrackersMu.Lock()
|
||||
defer detachedTunnelTrackersMu.Unlock()
|
||||
|
||||
var err error
|
||||
for tunnel := range detachedTunnelTrackers {
|
||||
if closeErr := tunnel.closeConnectionsForUpstream(upstream); closeErr != nil && tunnel == self && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func unregisterDetachedTunnelTrackers(ts *tunnelTracker) {
|
||||
detachedTunnelTrackersMu.Lock()
|
||||
defer detachedTunnelTrackersMu.Unlock()
|
||||
delete(detachedTunnelTrackers, ts)
|
||||
}
|
||||
|
||||
// Cleanup cleans up the resources made by h.
|
||||
func (h *Handler) Cleanup() error {
|
||||
// even if StreamDetached is true, extended connect websockets may still be running
|
||||
err := h.tunnelTracker.cleanupAttachedConnections()
|
||||
err := h.cleanupConnections()
|
||||
|
||||
// remove hosts from our config from the pool
|
||||
for _, upstream := range h.Upstreams {
|
||||
// hosts.Delete returns deleted=true when the ref count reaches zero,
|
||||
// meaning no other active config references this upstream. In that
|
||||
// case close any tunnels proxying to it; otherwise let them survive
|
||||
// to their natural end since the upstream is still in use.
|
||||
deleted, _ := hosts.Delete(upstream.String())
|
||||
if deleted {
|
||||
if closeErr := notifyDetachedTunnelTrackersOfUpstreamRemoval(upstream.String(), h.tunnelTracker); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
_, _ = hosts.Delete(upstream.String())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1242,11 +1137,10 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
|
||||
// we use the original request here, so that any routes from 'next'
|
||||
// see the original request rather than the proxy cloned request.
|
||||
hrc := &handleResponseContext{
|
||||
handler: h,
|
||||
response: res,
|
||||
start: start,
|
||||
logger: logger,
|
||||
upstreamAddr: di.Upstream.String(),
|
||||
handler: h,
|
||||
response: res,
|
||||
start: start,
|
||||
logger: logger,
|
||||
}
|
||||
ctx := origReq.Context()
|
||||
ctx = context.WithValue(ctx, proxyHandleResponseContextCtxKey, hrc)
|
||||
@@ -1276,7 +1170,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
|
||||
}
|
||||
|
||||
// copy the response body and headers back to the upstream client
|
||||
return h.finalizeResponse(rw, req, res, repl, start, logger, di.Upstream.String())
|
||||
return h.finalizeResponse(rw, req, res, repl, start, logger)
|
||||
}
|
||||
|
||||
// finalizeResponse prepares and copies the response.
|
||||
@@ -1287,11 +1181,12 @@ func (h *Handler) finalizeResponse(
|
||||
repl *caddy.Replacer,
|
||||
start time.Time,
|
||||
logger *zap.Logger,
|
||||
upstreamAddr string,
|
||||
) error {
|
||||
// deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode == http.StatusSwitchingProtocols {
|
||||
h.handleUpgradeResponse(logger, rw, req, res, upstreamAddr)
|
||||
var wg sync.WaitGroup
|
||||
h.handleUpgradeResponse(logger, &wg, rw, req, res)
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1898,22 +1793,6 @@ func (brc bodyReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StreamLogs controls logging for upgraded stream lifecycle events.
|
||||
type StreamLogs struct {
|
||||
// The minimum level at which stream lifecycle events are logged.
|
||||
// Supported values are debug, info, warn, and error. Default: debug.
|
||||
Level string `json:"level,omitempty"`
|
||||
|
||||
// Logger name for stream lifecycle logs. Default: "http.handlers.reverse_proxy.stream".
|
||||
// Special value "access" uses the access logger namespace and, if set,
|
||||
// respects the first value in access_logger_names/log_name for the request.
|
||||
LoggerName string `json:"logger_name,omitempty"`
|
||||
|
||||
// If true, suppresses the access log entry normally emitted when an
|
||||
// upgraded stream handshake completes and the request unwinds.
|
||||
SkipHandshake bool `json:"skip_handshake,omitempty"`
|
||||
}
|
||||
|
||||
// bufPool is used for buffering requests and responses.
|
||||
var bufPool = sync.Pool{
|
||||
New: func() any {
|
||||
@@ -1946,9 +1825,6 @@ type handleResponseContext struct {
|
||||
// i.e. copied and closed, to make sure that it doesn't
|
||||
// happen twice.
|
||||
isFinalized bool
|
||||
|
||||
// upstreamAddr is the selected upstream address for this request.
|
||||
upstreamAddr string
|
||||
}
|
||||
|
||||
// proxyHandleResponseContextCtxKey is the context key for the active proxy handler
|
||||
@@ -1959,13 +1835,6 @@ const proxyHandleResponseContextCtxKey caddy.CtxKey = "reverse_proxy_handle_resp
|
||||
// errNoUpstream occurs when there are no upstream available.
|
||||
var errNoUpstream = fmt.Errorf("no upstreams available")
|
||||
|
||||
const (
|
||||
defaultStreamLogLevel = zapcore.DebugLevel
|
||||
defaultStreamLoggerName = "http.handlers.reverse_proxy.stream"
|
||||
streamLoggerNameUseAccess = "access"
|
||||
defaultAccessLoggerBase = "http.log.access"
|
||||
)
|
||||
|
||||
// Interface guards
|
||||
var (
|
||||
_ caddy.Provisioner = (*Handler)(nil)
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"io"
|
||||
weakrand "math/rand/v2"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -36,16 +35,15 @@ import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
type extendedConnectReadWriteCloser struct {
|
||||
type h2ReadWriteCloser struct {
|
||||
io.ReadCloser
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
n, err = rwc.ResponseWriter.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -59,7 +57,7 @@ func (rwc extendedConnectReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response, upstreamAddr string) {
|
||||
func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||
reqUpType := upgradeType(req.Header)
|
||||
resUpType := upgradeType(res.Header)
|
||||
|
||||
@@ -92,37 +90,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
normalizeWebsocketHeaders(rw.Header())
|
||||
|
||||
// Capture all h fields needed by the tunnel now, so that the Handler (h)
|
||||
// is not referenced after this function returns (for HTTP/1.1 hijacked
|
||||
// connections the tunnel runs in a detached goroutine).
|
||||
tunnel := h.tunnelTracker
|
||||
bufferSize := h.StreamBufferSize
|
||||
streamTimeout := time.Duration(h.StreamTimeout)
|
||||
|
||||
if h.StreamDetached {
|
||||
// the return value should be true as it's not hijacked yet,
|
||||
// but some middleware may wrap response writers incorrectly
|
||||
if !caddyhttp.DetachResponseWriterAfterHijack(rw, true) {
|
||||
if c := logger.Check(zap.DebugLevel, "detaching connection failed"); c != nil {
|
||||
c.Write(zap.String("tip", "check if your response writers have an Unwrap method or if already hijacked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
detached = h.StreamDetached
|
||||
conn io.ReadWriteCloser
|
||||
brw *bufio.ReadWriter
|
||||
)
|
||||
// websocket over http2 or http3 if extended connect is enabled,
|
||||
// assuming backend doesn't support this, the request will be
|
||||
// modified to http1.1 upgrade
|
||||
// TODO: once we can reliably detect backend support this, it can
|
||||
// be removed for those backends
|
||||
// websocket over http2 or http3 if extended connect is enabled, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
|
||||
// TODO: once we can reliably detect backend support this, it can be removed for those backends
|
||||
if body, ok := caddyhttp.GetVar(req.Context(), "extended_connect_websocket_body").(io.ReadCloser); ok {
|
||||
// websocket over extended connect can't be detached. rw and req.Body
|
||||
// are only valid while the handler goroutine is running
|
||||
detached = false
|
||||
req.Body = body
|
||||
rw.Header().Del("Upgrade")
|
||||
rw.Header().Del("Connection")
|
||||
@@ -130,18 +104,18 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
|
||||
c.Write(zap.Int("http_version", req.ProtoMajor))
|
||||
c.Write(zap.Int("http_version", 2))
|
||||
}
|
||||
|
||||
//nolint:bodyclose
|
||||
flushErr := http.NewResponseController(rw).Flush()
|
||||
if flushErr != nil {
|
||||
if c := h.logger.Check(zap.ErrorLevel, "failed to flush extended_connect websocket response"); c != nil {
|
||||
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
|
||||
c.Write(zap.Error(flushErr))
|
||||
}
|
||||
return
|
||||
}
|
||||
conn = extendedConnectReadWriteCloser{req.Body, rw}
|
||||
conn = h2ReadWriteCloser{req.Body, rw}
|
||||
// bufio is not needed, use minimal buffer
|
||||
brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
|
||||
} else {
|
||||
@@ -169,6 +143,27 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
|
||||
}
|
||||
}
|
||||
|
||||
// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
|
||||
backConnCloseCh := make(chan struct{})
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
conn.Close()
|
||||
if c := logger.Check(zapcore.DebugLevel, "connection closed"); c != nil {
|
||||
c.Write(zap.Duration("duration", time.Since(start)))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
if c := logger.Check(zapcore.DebugLevel, "response flush"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
@@ -189,12 +184,13 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
|
||||
}
|
||||
}
|
||||
|
||||
// Register both connections with the tunnel tracker. We also try to
|
||||
// gracefully close connections we recognize as websockets. We need to make
|
||||
// sure the client connection messages (i.e. to upstream) are masked, so we
|
||||
// need to know whether the connection is considered the server or the
|
||||
// client side of the proxy. Note that gracefulClose must not capture h,
|
||||
// since the tunnel may outlive the handler instance.
|
||||
// Ensure the hijacked client connection, and the new connection established
|
||||
// with the backend, are both closed in the event of a server shutdown. This
|
||||
// is done by registering them. We also try to gracefully close connections
|
||||
// we recognize as websockets.
|
||||
// We need to make sure the client connection messages (i.e. to upstream)
|
||||
// are masked, so we need to know whether the connection is considered the
|
||||
// server or the client side of the proxy.
|
||||
gracefulClose := func(conn io.ReadWriteCloser, isClient bool) func() error {
|
||||
if isWebsocket(req) {
|
||||
return func() error {
|
||||
@@ -203,147 +199,43 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
deleteFrontConn := tunnel.registerConnection(conn, gracefulClose(conn, false), detached, upstreamAddr)
|
||||
deleteBackConn := tunnel.registerConnection(backConn, gracefulClose(backConn, true), detached, upstreamAddr)
|
||||
if h.streamLogsSkipHandshake() {
|
||||
caddyhttp.SetVar(req.Context(), caddyhttp.LogSkipVar, true)
|
||||
}
|
||||
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
repl.Set("http.reverse_proxy.upgraded", true)
|
||||
streamUUID, _ := repl.GetString("http.request.uuid")
|
||||
streamFields := makeStreamLogFields(streamUUID)
|
||||
streamLogger := h.streamLoggerForRequest(req)
|
||||
streamLevel := h.streamLogLevel
|
||||
finishMetrics := trackActiveStream(upstreamAddr)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
if !detached {
|
||||
handleUpgradeTunnel(
|
||||
streamLogger,
|
||||
streamLevel,
|
||||
conn,
|
||||
backConn,
|
||||
deleteFrontConn,
|
||||
deleteBackConn,
|
||||
bufferSize,
|
||||
streamTimeout,
|
||||
start,
|
||||
finishMetrics,
|
||||
streamFields,
|
||||
)
|
||||
} else {
|
||||
// start a new goroutine
|
||||
go handleUpgradeTunnel(
|
||||
streamLogger,
|
||||
streamLevel,
|
||||
conn,
|
||||
backConn,
|
||||
deleteFrontConn,
|
||||
deleteBackConn,
|
||||
bufferSize,
|
||||
streamTimeout,
|
||||
start,
|
||||
finishMetrics,
|
||||
streamFields,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpgradeTunnel returns when transfer is done.
|
||||
func handleUpgradeTunnel(
|
||||
streamLogger *zap.Logger,
|
||||
streamLevel zapcore.Level,
|
||||
conn io.ReadWriteCloser,
|
||||
backConn io.ReadWriteCloser,
|
||||
deleteFrontConn func(),
|
||||
deleteBackConn func(),
|
||||
bufferSize int,
|
||||
streamTimeout time.Duration,
|
||||
start time.Time,
|
||||
finishMetrics func(result string, duration time.Duration, toBackend int64, fromBackend int64),
|
||||
streamFields []zap.Field,
|
||||
) {
|
||||
defer deleteBackConn()
|
||||
deleteFrontConn := h.registerConnection(conn, gracefulClose(conn, false))
|
||||
deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn, true))
|
||||
defer deleteFrontConn()
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
toBackend int64
|
||||
fromBackend int64
|
||||
result string
|
||||
)
|
||||
defer deleteBackConn()
|
||||
|
||||
// when a stream timeout is encountered, no error will be read from errc
|
||||
// a buffer size of 2 will allow both the read and write goroutines to
|
||||
// send the error and exit
|
||||
// see: https://github.com/caddyserver/caddy/issues/7418
|
||||
errc := make(chan error, 2)
|
||||
spc := switchProtocolCopier{
|
||||
user: conn,
|
||||
backend: backConn,
|
||||
wg: &wg,
|
||||
bufferSize: bufferSize,
|
||||
sent: &toBackend,
|
||||
received: &fromBackend,
|
||||
wg: wg,
|
||||
bufferSize: h.StreamBufferSize,
|
||||
}
|
||||
wg.Add(2)
|
||||
|
||||
// setup the timeout if requested
|
||||
var timeoutc <-chan time.Time
|
||||
if streamTimeout > 0 {
|
||||
timer := time.NewTimer(streamTimeout)
|
||||
if h.StreamTimeout > 0 {
|
||||
timer := time.NewTimer(time.Duration(h.StreamTimeout))
|
||||
defer timer.Stop()
|
||||
timeoutc = timer.C
|
||||
}
|
||||
|
||||
// when a stream timeout is encountered, no error will be read from errc
|
||||
// a buffer size of 2 will allow both the read and write goroutines to send the error and exit
|
||||
// see: https://github.com/caddyserver/caddy/issues/7418
|
||||
errc := make(chan error, 2)
|
||||
wg.Add(2)
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
select {
|
||||
case err := <-errc:
|
||||
result = classifyStreamResult(err)
|
||||
if c := streamLogger.Check(streamLevel, "streaming error"); c != nil {
|
||||
if c := logger.Check(zapcore.DebugLevel, "streaming error"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
case t := <-timeoutc:
|
||||
result = "timeout"
|
||||
if c := streamLogger.Check(streamLevel, "stream timed out"); c != nil {
|
||||
c.Write(zap.Time("timeout", t))
|
||||
case time := <-timeoutc:
|
||||
if c := logger.Check(zapcore.DebugLevel, "stream timed out"); c != nil {
|
||||
c.Write(zap.Time("timeout", time))
|
||||
}
|
||||
}
|
||||
|
||||
// Close both ends to unblock the still-running copy goroutine,
|
||||
// then wait for it so byte counts are final before metrics/logging.
|
||||
conn.Close()
|
||||
backConn.Close()
|
||||
wg.Wait()
|
||||
|
||||
finishMetrics(result, time.Since(start), toBackend, fromBackend)
|
||||
if c := streamLogger.Check(streamLevel, "connection closed"); c != nil {
|
||||
fields := append([]zap.Field{}, streamFields...)
|
||||
fields = append(fields,
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
zap.Int64("bytes_to_backend", toBackend),
|
||||
zap.Int64("bytes_from_backend", fromBackend),
|
||||
)
|
||||
c.Write(fields...)
|
||||
}
|
||||
}
|
||||
|
||||
func classifyStreamResult(err error) string {
|
||||
if err == nil ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, net.ErrClosed) ||
|
||||
errors.Is(err, context.Canceled) {
|
||||
return "closed"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
|
||||
func makeStreamLogFields(streamUUID string) []zap.Field {
|
||||
fields := make([]zap.Field, 0, 1)
|
||||
if streamUUID != "" {
|
||||
fields = append(fields, zap.String("uuid", streamUUID))
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// flushInterval returns the p.FlushInterval value, conditionally
|
||||
@@ -483,101 +375,75 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte, logger *za
|
||||
}
|
||||
}
|
||||
|
||||
// openConnection maps an open connection to an optional function for graceful
|
||||
// close and records which upstream address the connection is proxying to.
|
||||
// Also tracks whether the connection is detached, which means it should only be
|
||||
// closed when the upstream is removed from the config, not on every reload.
|
||||
type openConnection struct {
|
||||
conn io.ReadWriteCloser
|
||||
gracefulClose func() error
|
||||
detached bool
|
||||
upstream string
|
||||
}
|
||||
|
||||
// tunnelTracker tracks hijacked/upgraded connections for selective cleanup.
|
||||
// This exists to detach the lifecycle of streaming connections from the proxy
|
||||
// Handler and config, since we typically want them to survive past config reloads.
|
||||
// It also allows for selective connection cleanup based on their attachment status.
|
||||
type tunnelTracker struct {
|
||||
connections map[io.ReadWriteCloser]openConnection
|
||||
closeTimer *time.Timer
|
||||
closeDelay time.Duration
|
||||
stopped bool
|
||||
mu sync.Mutex
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newTunnelTracker(logger *zap.Logger, closeDelay time.Duration) *tunnelTracker {
|
||||
return &tunnelTracker{
|
||||
connections: make(map[io.ReadWriteCloser]openConnection),
|
||||
closeDelay: closeDelay,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// registerConnection stores conn in the tracking map. The caller must invoke
|
||||
// the returned del func when the connection is done.
|
||||
func (ts *tunnelTracker) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error, detached bool, upstream string) (del func()) {
|
||||
ts.mu.Lock()
|
||||
ts.connections[conn] = openConnection{conn, gracefulClose, detached, upstream}
|
||||
ts.mu.Unlock()
|
||||
// registerConnection holds onto conn so it can be closed in the event
|
||||
// of a server shutdown. This is useful because hijacked connections or
|
||||
// connections dialed to backends don't close when server is shut down.
|
||||
// The caller should call the returned delete() function when the
|
||||
// connection is done to remove it from memory.
|
||||
func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
|
||||
h.connectionsMu.Lock()
|
||||
h.connections[conn] = openConnection{conn, gracefulClose}
|
||||
h.connectionsMu.Unlock()
|
||||
return func() {
|
||||
ts.mu.Lock()
|
||||
delete(ts.connections, conn)
|
||||
if len(ts.connections) == 0 && ts.stopped {
|
||||
unregisterDetachedTunnelTrackers(ts)
|
||||
if ts.closeTimer != nil {
|
||||
if ts.closeTimer.Stop() {
|
||||
ts.logger.Debug("stopped streaming connections close timer - all connections are already closed")
|
||||
}
|
||||
ts.closeTimer = nil
|
||||
h.connectionsMu.Lock()
|
||||
delete(h.connections, conn)
|
||||
// if there is no connection left before the connections close timer fires
|
||||
if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
|
||||
// we release the timer that holds the reference to Handler
|
||||
if (*h.connectionsCloseTimer).Stop() {
|
||||
h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
|
||||
}
|
||||
h.connectionsCloseTimer = nil
|
||||
}
|
||||
ts.mu.Unlock()
|
||||
h.connectionsMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// closeAttachedConnections closes all tracked attached connections.
|
||||
func (ts *tunnelTracker) closeAttachedConnections() error {
|
||||
// closeConnections immediately closes all hijacked connections (both to client and backend).
|
||||
func (h *Handler) closeConnections() error {
|
||||
var err error
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.stopped = true
|
||||
for _, oc := range ts.connections {
|
||||
// detached connections are only closed when the upstream is gone from the config
|
||||
if oc.detached {
|
||||
continue
|
||||
}
|
||||
h.connectionsMu.Lock()
|
||||
defer h.connectionsMu.Unlock()
|
||||
|
||||
for _, oc := range h.connections {
|
||||
if oc.gracefulClose != nil {
|
||||
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
|
||||
// this is potentially blocking while we have the lock on the connections
|
||||
// map, but that should be OK since the server has in theory shut down
|
||||
// and we are no longer using the connections map
|
||||
gracefulErr := oc.gracefulClose()
|
||||
if gracefulErr != nil && err == nil {
|
||||
err = gracefulErr
|
||||
}
|
||||
}
|
||||
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
|
||||
closeErr := oc.conn.Close()
|
||||
if closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanupAttachedConnections closes upgraded attached connections.
|
||||
// Depending on closeDelay it does that either immediately or after a timer.
|
||||
func (ts *tunnelTracker) cleanupAttachedConnections() error {
|
||||
if ts.closeDelay == 0 {
|
||||
return ts.closeAttachedConnections()
|
||||
// cleanupConnections closes hijacked connections.
|
||||
// Depending on the value of StreamCloseDelay it does that either immediately
|
||||
// or sets up a timer that will do that later.
|
||||
func (h *Handler) cleanupConnections() error {
|
||||
if h.StreamCloseDelay == 0 {
|
||||
return h.closeConnections()
|
||||
}
|
||||
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if len(ts.connections) > 0 {
|
||||
delay := ts.closeDelay
|
||||
ts.closeTimer = time.AfterFunc(delay, func() {
|
||||
if c := ts.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
|
||||
h.connectionsMu.Lock()
|
||||
defer h.connectionsMu.Unlock()
|
||||
// the handler is shut down, no new connection can appear,
|
||||
// so we can skip setting up the timer when there are no connections
|
||||
if len(h.connections) > 0 {
|
||||
delay := time.Duration(h.StreamCloseDelay)
|
||||
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
|
||||
if c := h.logger.Check(zapcore.DebugLevel, "closing streaming connections after delay"); c != nil {
|
||||
c.Write(zap.Duration("delay", delay))
|
||||
}
|
||||
err := ts.closeAttachedConnections()
|
||||
err := h.closeConnections()
|
||||
if err != nil {
|
||||
if c := ts.logger.Check(zapcore.ErrorLevel, "failed to close connections after delay"); c != nil {
|
||||
if c := h.logger.Check(zapcore.ErrorLevel, "failed to closed connections after delay"); c != nil {
|
||||
c.Write(
|
||||
zap.Error(err),
|
||||
zap.Duration("delay", delay),
|
||||
@@ -701,29 +567,11 @@ func isWebsocket(r *http.Request) bool {
|
||||
httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
|
||||
}
|
||||
|
||||
// closeConnectionsForUpstream closes all tracked connections that were
|
||||
// established to the given upstream address.
|
||||
func (ts *tunnelTracker) closeConnectionsForUpstream(addr string) error {
|
||||
var err error
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if !ts.stopped {
|
||||
return nil
|
||||
}
|
||||
for _, oc := range ts.connections {
|
||||
if oc.upstream != addr {
|
||||
continue
|
||||
}
|
||||
if oc.gracefulClose != nil {
|
||||
if gracefulErr := oc.gracefulClose(); gracefulErr != nil && err == nil {
|
||||
err = gracefulErr
|
||||
}
|
||||
}
|
||||
if closeErr := oc.conn.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
// openConnection maps an open connection to
|
||||
// an optional function for graceful close.
|
||||
type openConnection struct {
|
||||
conn io.ReadWriteCloser
|
||||
gracefulClose func() error
|
||||
}
|
||||
|
||||
type maxLatencyWriter struct {
|
||||
@@ -794,23 +642,16 @@ type switchProtocolCopier struct {
|
||||
user, backend io.ReadWriteCloser
|
||||
wg *sync.WaitGroup
|
||||
bufferSize int
|
||||
// sent and received accumulate byte counts for each direction.
|
||||
// They are written before wg.Done() and read after wg.Wait(), so no
|
||||
// additional synchronization is needed beyond the WaitGroup barrier.
|
||||
sent *int64 // bytes copied to backend; must be non-nil
|
||||
received *int64 // bytes copied from backend; must be non-nil
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
n, err := io.CopyBuffer(c.user, c.backend, c.buffer())
|
||||
*c.received = n
|
||||
_, err := io.CopyBuffer(c.user, c.backend, c.buffer())
|
||||
errc <- err
|
||||
c.wg.Done()
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
n, err := io.CopyBuffer(c.backend, c.user, c.buffer())
|
||||
*c.sent = n
|
||||
_, err := io.CopyBuffer(c.backend, c.user, c.buffer())
|
||||
errc <- err
|
||||
c.wg.Done()
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
)
|
||||
|
||||
func TestHandlerCopyResponse(t *testing.T) {
|
||||
@@ -43,15 +41,12 @@ func TestSwitchProtocolCopierBufferSize(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
var errc = make(chan error, 1)
|
||||
var dst bytes.Buffer
|
||||
var sent, received int64
|
||||
|
||||
copier := switchProtocolCopier{
|
||||
user: nopReadWriteCloser{Reader: strings.NewReader("hello")},
|
||||
backend: nopReadWriteCloser{Writer: &dst},
|
||||
wg: &wg,
|
||||
bufferSize: 7,
|
||||
sent: &sent,
|
||||
received: &received,
|
||||
}
|
||||
|
||||
buf := copier.buffer()
|
||||
@@ -85,146 +80,3 @@ type nopReadWriteCloser struct {
|
||||
}
|
||||
|
||||
func (nopReadWriteCloser) Close() error { return nil }
|
||||
|
||||
type trackingReadWriteCloser struct {
|
||||
closed chan struct{}
|
||||
one sync.Once
|
||||
}
|
||||
|
||||
func newTrackingReadWriteCloser() *trackingReadWriteCloser {
|
||||
return &trackingReadWriteCloser{closed: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *trackingReadWriteCloser) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (c *trackingReadWriteCloser) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (c *trackingReadWriteCloser) Close() error {
|
||||
c.one.Do(func() {
|
||||
close(c.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *trackingReadWriteCloser) isClosed() bool {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerCleanupLegacyModeClosesAllConnections(t *testing.T) {
|
||||
ts := newTunnelTracker(caddy.Log(), 0)
|
||||
connA := newTrackingReadWriteCloser()
|
||||
connB := newTrackingReadWriteCloser()
|
||||
ts.registerConnection(connA, nil, false, "a")
|
||||
ts.registerConnection(connB, nil, false, "b")
|
||||
|
||||
h := &Handler{
|
||||
tunnelTracker: ts,
|
||||
StreamDetached: false,
|
||||
}
|
||||
|
||||
if err := h.Cleanup(); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
if !connA.isClosed() || !connB.isClosed() {
|
||||
t.Fatalf("legacy cleanup should close all upgraded connections")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerCleanupLegacyModeHonorsDelay(t *testing.T) {
|
||||
ts := newTunnelTracker(caddy.Log(), 40*time.Millisecond)
|
||||
conn := newTrackingReadWriteCloser()
|
||||
ts.registerConnection(conn, nil, false, "a")
|
||||
|
||||
h := &Handler{
|
||||
tunnelTracker: ts,
|
||||
StreamDetached: false,
|
||||
}
|
||||
|
||||
if err := h.Cleanup(); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
if conn.isClosed() {
|
||||
t.Fatal("connection should not close immediately when stream_close_delay is set")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-conn.closed:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("connection did not close after stream_close_delay elapsed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerCleanupDetachedModeClosesOnlyRemovedUpstreams(t *testing.T) {
|
||||
const upstreamA = "upstream-a"
|
||||
const upstreamB = "upstream-b"
|
||||
|
||||
// Simulate old+new configs both referencing upstreamA (refcount 2),
|
||||
// while upstreamB is only referenced by the old config (refcount 1).
|
||||
hosts.LoadOrStore(upstreamA, struct{}{})
|
||||
hosts.LoadOrStore(upstreamA, struct{}{})
|
||||
hosts.LoadOrStore(upstreamB, struct{}{})
|
||||
t.Cleanup(func() {
|
||||
_, _ = hosts.Delete(upstreamA)
|
||||
_, _ = hosts.Delete(upstreamA)
|
||||
_, _ = hosts.Delete(upstreamB)
|
||||
})
|
||||
|
||||
ts := newTunnelTracker(caddy.Log(), 0)
|
||||
registerDetachedTunnelTrackers(ts)
|
||||
connA := newTrackingReadWriteCloser()
|
||||
connB := newTrackingReadWriteCloser()
|
||||
ts.registerConnection(connA, nil, true, upstreamA)
|
||||
ts.registerConnection(connB, nil, true, upstreamB)
|
||||
|
||||
h := &Handler{
|
||||
tunnelTracker: ts,
|
||||
StreamDetached: true,
|
||||
Upstreams: UpstreamPool{
|
||||
&Upstream{Dial: upstreamA},
|
||||
&Upstream{Dial: upstreamB},
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.Cleanup(); err != nil {
|
||||
t.Fatalf("cleanup failed: %v", err)
|
||||
}
|
||||
|
||||
if connA.isClosed() {
|
||||
t.Fatal("connection for detached upstream should remain open")
|
||||
}
|
||||
if !connB.isClosed() {
|
||||
t.Fatal("connection for removed upstream should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerUnmarshalCaddyfileStreamLogsBlock(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(`
|
||||
reverse_proxy localhost:9000 {
|
||||
stream_logs {
|
||||
level info
|
||||
logger_name access
|
||||
skip_handshake
|
||||
}
|
||||
}
|
||||
`)
|
||||
|
||||
var h Handler
|
||||
if err := h.UnmarshalCaddyfile(d); err != nil {
|
||||
t.Fatalf("UnmarshalCaddyfile() error = %v", err)
|
||||
}
|
||||
if h.StreamLogs == nil {
|
||||
t.Fatal("expected stream_logs to be configured")
|
||||
}
|
||||
if h.StreamLogs.Level != "info" {
|
||||
t.Fatalf("expected stream_logs.level=info, got %q", h.StreamLogs.Level)
|
||||
}
|
||||
if h.StreamLogs.LoggerName != "access" {
|
||||
t.Fatalf("expected stream_logs.logger_name=access, got %q", h.StreamLogs.LoggerName)
|
||||
}
|
||||
if !h.StreamLogs.SkipHandshake {
|
||||
t.Fatal("expected stream_logs.skip_handshake=true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReverse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple string",
|
||||
input: "hello",
|
||||
expected: "olleh",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single character",
|
||||
input: "a",
|
||||
expected: "a",
|
||||
},
|
||||
{
|
||||
name: "two characters",
|
||||
input: "ab",
|
||||
expected: "ba",
|
||||
},
|
||||
{
|
||||
name: "palindrome",
|
||||
input: "racecar",
|
||||
expected: "racecar",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
input: "hello world",
|
||||
expected: "dlrow olleh",
|
||||
},
|
||||
{
|
||||
name: "with numbers",
|
||||
input: "abc123",
|
||||
expected: "321cba",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
input: "hello世界",
|
||||
expected: "界世olleh",
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "🎉🎊🎈",
|
||||
expected: "🎈🎊🎉",
|
||||
},
|
||||
{
|
||||
name: "mixed unicode and ascii",
|
||||
input: "café☕",
|
||||
expected: "☕éfac",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "a!b@c#d$",
|
||||
expected: "$d#c@b!a",
|
||||
},
|
||||
{
|
||||
name: "path-like string",
|
||||
input: "/path/to/file",
|
||||
expected: "elif/ot/htap/",
|
||||
},
|
||||
{
|
||||
name: "url-like string",
|
||||
input: "https://example.com",
|
||||
expected: "moc.elpmaxe//:sptth",
|
||||
},
|
||||
{
|
||||
name: "long string",
|
||||
input: "The quick brown fox jumps over the lazy dog",
|
||||
expected: "god yzal eht revo spmuj xof nworb kciuq ehT",
|
||||
},
|
||||
{
|
||||
name: "newlines",
|
||||
input: "line1\nline2\nline3",
|
||||
expected: "3enil\n2enil\n1enil",
|
||||
},
|
||||
{
|
||||
name: "tabs",
|
||||
input: "a\tb\tc",
|
||||
expected: "c\tb\ta",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reverse(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("reverse(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
|
||||
// Test that reversing twice gives the original string
|
||||
if tt.input != "" {
|
||||
doubleReverse := reverse(reverse(tt.input))
|
||||
if doubleReverse != tt.input {
|
||||
t.Errorf("reverse(reverse(%q)) = %q; want %q", tt.input, doubleReverse, tt.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverse_LengthPreservation(t *testing.T) {
|
||||
// Test that reverse preserves string length
|
||||
testStrings := []string{
|
||||
"",
|
||||
"a",
|
||||
"ab",
|
||||
"abc",
|
||||
"hello world",
|
||||
"🎉🎊🎈",
|
||||
"café☕",
|
||||
"The quick brown fox jumps over the lazy dog",
|
||||
}
|
||||
|
||||
for _, s := range testStrings {
|
||||
reversed := reverse(s)
|
||||
if len([]rune(s)) != len([]rune(reversed)) {
|
||||
t.Errorf("reverse(%q) changed length: original %d, reversed %d", s, len([]rune(s)), len([]rune(reversed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkReverse benchmarks the reverse function
|
||||
func BenchmarkReverse(b *testing.B) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"short", "hello"},
|
||||
{"medium", "The quick brown fox jumps over the lazy dog"},
|
||||
{"long", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."},
|
||||
{"unicode", "hello世界🎉"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
reverse(tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverse_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"null byte", "\x00"},
|
||||
{"multiple null bytes", "\x00\x00\x00"},
|
||||
{"control characters", "\t\n\r"},
|
||||
{"high unicode", "𝕳𝖊𝖑𝖑𝖔"},
|
||||
{"zero-width characters", "a\u200Bb\u200Cc"},
|
||||
{"combining characters", "é"}, // e + combining acute
|
||||
{"rtl text", "مرحبا"},
|
||||
{"mixed rtl/ltr", "Hello مرحبا World"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reverse(tt.input)
|
||||
// Just ensure it doesn't panic and returns something
|
||||
if result == "" && tt.input != "" {
|
||||
t.Errorf("reverse(%q) returned empty string", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
)
|
||||
|
||||
func TestStaticErrorCaddyModule(t *testing.T) {
|
||||
se := StaticError{}
|
||||
info := se.CaddyModule()
|
||||
if info.ID != "http.handlers.error" {
|
||||
t.Errorf("CaddyModule().ID = %q, want 'http.handlers.error'", info.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticErrorServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
staticErr StaticError
|
||||
wantStatusCode int
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "default status code 500",
|
||||
staticErr: StaticError{},
|
||||
wantStatusCode: 500,
|
||||
},
|
||||
{
|
||||
name: "custom status code",
|
||||
staticErr: StaticError{StatusCode: "404"},
|
||||
wantStatusCode: 404,
|
||||
},
|
||||
{
|
||||
name: "custom error message",
|
||||
staticErr: StaticError{Error: "custom error", StatusCode: "503"},
|
||||
wantStatusCode: 503,
|
||||
wantMessage: "custom error",
|
||||
},
|
||||
{
|
||||
name: "status code only",
|
||||
staticErr: StaticError{StatusCode: "403"},
|
||||
wantStatusCode: 403,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(context.Background(), caddy.ReplacerCtxKey, repl)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err := tt.staticErr.ServeHTTP(w, req, nil)
|
||||
if err == nil {
|
||||
t.Fatal("ServeHTTP() should return an error")
|
||||
}
|
||||
|
||||
var he HandlerError
|
||||
if !errors.As(err, &he) {
|
||||
t.Fatal("ServeHTTP() error should be HandlerError")
|
||||
}
|
||||
|
||||
if he.StatusCode != tt.wantStatusCode {
|
||||
t.Errorf("StatusCode = %d, want %d", he.StatusCode, tt.wantStatusCode)
|
||||
}
|
||||
|
||||
if tt.wantMessage != "" && he.Err != nil {
|
||||
if he.Err.Error() != tt.wantMessage {
|
||||
t.Errorf("Err.Error() = %q, want %q", he.Err.Error(), tt.wantMessage)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticErrorServeHTTPInvalidStatusCode(t *testing.T) {
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(context.Background(), caddy.ReplacerCtxKey, repl)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
se := StaticError{StatusCode: "not_a_number"}
|
||||
err := se.ServeHTTP(w, req, nil)
|
||||
if err == nil {
|
||||
t.Fatal("ServeHTTP() should return error for invalid status code")
|
||||
}
|
||||
|
||||
var he HandlerError
|
||||
if !errors.As(err, &he) {
|
||||
t.Fatal("error should be HandlerError")
|
||||
}
|
||||
// Invalid status code should return 500
|
||||
if he.StatusCode != 500 {
|
||||
t.Errorf("StatusCode = %d, want 500 for invalid status code", he.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticErrorUnmarshalCaddyfile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
wantStatus string
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "status code only",
|
||||
input: `error 404`,
|
||||
wantStatus: "404",
|
||||
},
|
||||
{
|
||||
name: "message only (non-3-digit)",
|
||||
input: `error "Page not found"`,
|
||||
wantMsg: "Page not found",
|
||||
},
|
||||
{
|
||||
name: "message and status code",
|
||||
input: `error "Page not found" 404`,
|
||||
wantStatus: "404",
|
||||
wantMsg: "Page not found",
|
||||
},
|
||||
{
|
||||
name: "no args",
|
||||
input: `error`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "too many args",
|
||||
input: `error "msg" 404 extra`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "status in block",
|
||||
input: "error 500 {\n message \"server error\"\n}",
|
||||
wantStatus: "500",
|
||||
wantMsg: "server error",
|
||||
},
|
||||
{
|
||||
name: "two-digit number is treated as message",
|
||||
input: `error 42`,
|
||||
wantMsg: "42",
|
||||
},
|
||||
{
|
||||
name: "four-digit number is treated as message",
|
||||
input: `error 1234`,
|
||||
wantMsg: "1234",
|
||||
},
|
||||
{
|
||||
name: "three-digit is status code",
|
||||
input: `error 503`,
|
||||
wantStatus: "503",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(tt.input)
|
||||
se := &StaticError{}
|
||||
err := se.UnmarshalCaddyfile(d)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UnmarshalCaddyfile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantStatus != "" && string(se.StatusCode) != tt.wantStatus {
|
||||
t.Errorf("StatusCode = %q, want %q", se.StatusCode, tt.wantStatus)
|
||||
}
|
||||
if tt.wantMsg != "" && se.Error != tt.wantMsg {
|
||||
t.Errorf("Error = %q, want %q", se.Error, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticErrorUnmarshalCaddyfileDuplicateMessage(t *testing.T) {
|
||||
input := "error \"first message\" 500 {\n message \"second message\"\n}"
|
||||
d := caddyfile.NewTestDispenser(input)
|
||||
se := &StaticError{}
|
||||
err := se.UnmarshalCaddyfile(d)
|
||||
if err == nil {
|
||||
t.Error("expected error when message is specified both inline and in block")
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,158 @@ import (
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
)
|
||||
|
||||
func TestGetVarAndSetVar(t *testing.T) {
|
||||
vars := map[string]any{
|
||||
"existing_key": "existing_value",
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), VarsCtxKey, vars)
|
||||
|
||||
if v := GetVar(ctx, "existing_key"); v != "existing_value" {
|
||||
t.Errorf("GetVar() = %v, want 'existing_value'", v)
|
||||
}
|
||||
|
||||
if v := GetVar(ctx, "nonexistent_key"); v != nil {
|
||||
t.Errorf("GetVar() for missing key = %v, want nil", v)
|
||||
}
|
||||
|
||||
emptyCtx := context.Background()
|
||||
if v := GetVar(emptyCtx, "any"); v != nil {
|
||||
t.Errorf("GetVar() on context without vars = %v, want nil", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetVar(t *testing.T) {
|
||||
vars := map[string]any{}
|
||||
ctx := context.WithValue(context.Background(), VarsCtxKey, vars)
|
||||
|
||||
SetVar(ctx, "key1", "value1")
|
||||
if vars["key1"] != "value1" {
|
||||
t.Errorf("SetVar() didn't set value, got %v", vars["key1"])
|
||||
}
|
||||
|
||||
SetVar(ctx, "key1", "value2")
|
||||
if vars["key1"] != "value2" {
|
||||
t.Errorf("SetVar() didn't overwrite value, got %v", vars["key1"])
|
||||
}
|
||||
|
||||
SetVar(ctx, "key1", nil)
|
||||
if _, ok := vars["key1"]; ok {
|
||||
t.Error("SetVar(nil) should delete the key")
|
||||
}
|
||||
|
||||
// BUG: SetVar with nil for non-existent key should be a no-op per its documentation,
|
||||
// but it actually inserts a nil value into the map. The nil check only deletes
|
||||
// existing keys; if the key doesn't exist, execution falls through to the
|
||||
// final `varMap[key] = value` line, storing nil.
|
||||
SetVar(ctx, "nonexistent", nil)
|
||||
if _, ok := vars["nonexistent"]; !ok {
|
||||
t.Error("BUG: SetVar(nil) for non-existent key unexpectedly did NOT set the key. If this passes, the bug described in code comments may have been fixed.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetVarWithoutContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
SetVar(ctx, "key", "value")
|
||||
}
|
||||
|
||||
func TestVarsMiddlewareCaddyModule(t *testing.T) {
|
||||
m := VarsMiddleware{}
|
||||
info := m.CaddyModule()
|
||||
if info.ID != "http.handlers.vars" {
|
||||
t.Errorf("CaddyModule().ID = %v, want 'http.handlers.vars'", info.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVarsMatcherEmptyMatch(t *testing.T) {
|
||||
m := VarsMatcher{}
|
||||
|
||||
vars := map[string]any{}
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(context.Background(), VarsCtxKey, vars)
|
||||
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, repl)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
match, err := m.MatchWithError(req)
|
||||
if err != nil {
|
||||
t.Fatalf("MatchWithError() error = %v", err)
|
||||
}
|
||||
if !match {
|
||||
t.Error("empty VarsMatcher should match everything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVarsMatcherMatch(t *testing.T) {
|
||||
vars := map[string]any{
|
||||
"my_var": "hello",
|
||||
}
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(context.Background(), VarsCtxKey, vars)
|
||||
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, repl)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
matcher VarsMatcher
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
name: "matching variable",
|
||||
matcher: VarsMatcher{"my_var": {"hello"}},
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "non-matching variable",
|
||||
matcher: VarsMatcher{"my_var": {"world"}},
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent variable",
|
||||
matcher: VarsMatcher{"nonexistent": {"anything"}},
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "multiple values OR",
|
||||
matcher: VarsMatcher{"my_var": {"world", "hello", "foo"}},
|
||||
wantMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match := tt.matcher.Match(req)
|
||||
if match != tt.wantMatch {
|
||||
t.Errorf("Match() = %v, want %v", match, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVarsMatcherWithNilVarValue(t *testing.T) {
|
||||
vars := map[string]any{
|
||||
"nil_var": nil,
|
||||
}
|
||||
repl := caddy.NewReplacer()
|
||||
ctx := context.WithValue(context.Background(), VarsCtxKey, vars)
|
||||
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, repl)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
m := VarsMatcher{"nil_var": {""}}
|
||||
match, err := m.MatchWithError(req)
|
||||
if err != nil {
|
||||
t.Fatalf("MatchWithError() error = %v", err)
|
||||
}
|
||||
if !match {
|
||||
t.Error("nil variable value should match empty string")
|
||||
}
|
||||
}
|
||||
|
||||
func newVarsTestRequest(t *testing.T, target string, headers http.Header, vars map[string]any) (*http.Request, *caddy.Replacer) {
|
||||
t.Helper()
|
||||
|
||||
@@ -38,8 +190,6 @@ func newVarsTestRequest(t *testing.T, target string, headers http.Header, vars m
|
||||
if vars == nil {
|
||||
vars = make(map[string]any)
|
||||
}
|
||||
// Inject vars directly so these tests exercise matcher-side handling of
|
||||
// already-resolved values, not VarsMiddleware placeholder expansion.
|
||||
ctx = context.WithValue(ctx, VarsCtxKey, vars)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
|
||||
@@ -896,18 +896,19 @@ func (clientauth *ClientAuthentication) ConfigureTLSConfig(cfg *tls.Config) erro
|
||||
// Unlike VerifyPeerCertificate, VerifyConnection is called on every
|
||||
// connection including resumed sessions, preventing session-resumption bypass.
|
||||
func (clientauth *ClientAuthentication) verifyConnection(cs tls.ConnectionState) error {
|
||||
rawCerts := make([][]byte, len(cs.PeerCertificates))
|
||||
for i, cert := range cs.PeerCertificates {
|
||||
rawCerts[i] = cert.Raw
|
||||
}
|
||||
|
||||
// first use any pre-existing custom verification function
|
||||
if clientauth.existingVerifyPeerCert != nil {
|
||||
rawCerts := make([][]byte, len(cs.PeerCertificates))
|
||||
for i, cert := range cs.PeerCertificates {
|
||||
rawCerts[i] = cert.Raw
|
||||
}
|
||||
if err := clientauth.existingVerifyPeerCert(rawCerts, cs.VerifiedChains); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, verifier := range clientauth.verifiers {
|
||||
if err := verifier.VerifyClientCertificate(nil, cs.VerifiedChains); err != nil {
|
||||
if err := verifier.VerifyClientCertificate(rawCerts, cs.VerifiedChains); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testClientCertificateVerifier struct {
|
||||
rawCerts [][]byte
|
||||
verifiedChains [][]*x509.Certificate
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *testClientCertificateVerifier) VerifyClientCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
v.rawCerts = rawCerts
|
||||
v.verifiedChains = verifiedChains
|
||||
return v.err
|
||||
}
|
||||
|
||||
func TestClientAuthenticationVerifyConnectionPassesRawCertsToVerifiers(t *testing.T) {
|
||||
verifier := &testClientCertificateVerifier{}
|
||||
clientauth := &ClientAuthentication{
|
||||
verifiers: []ClientCertificateVerifier{verifier},
|
||||
}
|
||||
|
||||
peerCert := &x509.Certificate{Raw: []byte("peer-cert-raw")}
|
||||
verifiedChains := [][]*x509.Certificate{{peerCert}}
|
||||
connState := tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{peerCert},
|
||||
VerifiedChains: verifiedChains,
|
||||
}
|
||||
|
||||
if err := clientauth.verifyConnection(connState); err != nil {
|
||||
t.Fatalf("verifyConnection failed: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(verifier.rawCerts, [][]byte{[]byte("peer-cert-raw")}) {
|
||||
t.Fatalf("unexpected raw certs: got %#v", verifier.rawCerts)
|
||||
}
|
||||
if !reflect.DeepEqual(verifier.verifiedChains, verifiedChains) {
|
||||
t.Fatalf("unexpected verified chains: got %#v", verifier.verifiedChains)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientAuthenticationVerifyConnectionReturnsVerifierError(t *testing.T) {
|
||||
wantErr := errors.New("verify failed")
|
||||
verifier := &testClientCertificateVerifier{err: wantErr}
|
||||
clientauth := &ClientAuthentication{
|
||||
verifiers: []ClientCertificateVerifier{verifier},
|
||||
}
|
||||
|
||||
err := clientauth.verifyConnection(tls.ConnectionState{})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("expected error %v, got %v", wantErr, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package filestorage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
)
|
||||
|
||||
func TestFileStorageCaddyModule(t *testing.T) {
|
||||
fs := FileStorage{}
|
||||
info := fs.CaddyModule()
|
||||
if info.ID != "caddy.storage.file_system" {
|
||||
t.Errorf("CaddyModule().ID = %q, want 'caddy.storage.file_system'", info.ID)
|
||||
}
|
||||
mod := info.New()
|
||||
if mod == nil {
|
||||
t.Error("New() should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorageCertMagicStorage(t *testing.T) {
|
||||
fs := FileStorage{Root: "/var/lib/caddy/certs"}
|
||||
storage, err := fs.CertMagicStorage()
|
||||
if err != nil {
|
||||
t.Fatalf("CertMagicStorage() error = %v", err)
|
||||
}
|
||||
if storage == nil {
|
||||
t.Fatal("CertMagicStorage() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorageCertMagicStorageEmptyRoot(t *testing.T) {
|
||||
fs := FileStorage{Root: ""}
|
||||
storage, err := fs.CertMagicStorage()
|
||||
if err != nil {
|
||||
t.Fatalf("CertMagicStorage() error = %v", err)
|
||||
}
|
||||
if storage == nil {
|
||||
t.Fatal("CertMagicStorage() returned nil even with empty root")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorageUnmarshalCaddyfile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
wantVal string
|
||||
}{
|
||||
{
|
||||
name: "root as inline arg",
|
||||
input: `file_system /var/lib/caddy`,
|
||||
wantVal: "/var/lib/caddy",
|
||||
},
|
||||
{
|
||||
name: "root in block",
|
||||
input: "file_system {\n\troot /var/lib/caddy\n}",
|
||||
wantVal: "/var/lib/caddy",
|
||||
},
|
||||
{
|
||||
name: "missing root",
|
||||
input: `file_system`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "too many inline args",
|
||||
input: `file_system /path1 /path2`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "root already set inline then block",
|
||||
input: "file_system /path1 {\n\troot /path2\n}",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown subdirective",
|
||||
input: "file_system {\n\tunknown_option value\n}",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "root in block without value",
|
||||
input: "file_system {\n\troot\n}",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := caddyfile.NewTestDispenser(tt.input)
|
||||
fs := &FileStorage{}
|
||||
err := fs.UnmarshalCaddyfile(d)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UnmarshalCaddyfile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && fs.Root != tt.wantVal {
|
||||
t.Errorf("Root = %q, want %q", fs.Root, tt.wantVal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+963
@@ -0,0 +1,963 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNetworkAddress_String_Consistency(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
}{
|
||||
{
|
||||
name: "basic tcp",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8080},
|
||||
},
|
||||
{
|
||||
name: "tcp with port range",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8090},
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"},
|
||||
},
|
||||
{
|
||||
name: "udp",
|
||||
addr: NetworkAddress{Network: "udp", Host: "0.0.0.0", StartPort: 53, EndPort: 53},
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "::1", StartPort: 80, EndPort: 80},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
str := test.addr.String()
|
||||
|
||||
// Parse the string back
|
||||
parsed, err := ParseNetworkAddress(str)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse string representation: %v", err)
|
||||
}
|
||||
|
||||
// Should be equivalent to original
|
||||
if parsed.Network != test.addr.Network {
|
||||
t.Errorf("Network mismatch: expected %s, got %s", test.addr.Network, parsed.Network)
|
||||
}
|
||||
if parsed.Host != test.addr.Host {
|
||||
t.Errorf("Host mismatch: expected %s, got %s", test.addr.Host, parsed.Host)
|
||||
}
|
||||
if parsed.StartPort != test.addr.StartPort {
|
||||
t.Errorf("StartPort mismatch: expected %d, got %d", test.addr.StartPort, parsed.StartPort)
|
||||
}
|
||||
if parsed.EndPort != test.addr.EndPort {
|
||||
t.Errorf("EndPort mismatch: expected %d, got %d", test.addr.EndPort, parsed.EndPort)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_PortRangeSize_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected uint
|
||||
}{
|
||||
{
|
||||
name: "single port",
|
||||
addr: NetworkAddress{StartPort: 80, EndPort: 80},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "invalid range (end < start)",
|
||||
addr: NetworkAddress{StartPort: 8080, EndPort: 8070},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "zero ports",
|
||||
addr: NetworkAddress{StartPort: 0, EndPort: 0},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "maximum range",
|
||||
addr: NetworkAddress{StartPort: 1, EndPort: 65535},
|
||||
expected: 65535,
|
||||
},
|
||||
{
|
||||
name: "large range",
|
||||
addr: NetworkAddress{StartPort: 8000, EndPort: 9000},
|
||||
expected: 1001,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
size := test.addr.PortRangeSize()
|
||||
if size != test.expected {
|
||||
t.Errorf("Expected %d, got %d", test.expected, size)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_At_Validation(t *testing.T) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8080,
|
||||
EndPort: 8090,
|
||||
}
|
||||
|
||||
// Test valid offsets
|
||||
for offset := uint(0); offset <= 10; offset++ {
|
||||
result := addr.At(offset)
|
||||
expectedPort := 8080 + offset
|
||||
|
||||
if result.StartPort != expectedPort || result.EndPort != expectedPort {
|
||||
t.Errorf("Offset %d: expected port %d, got %d-%d",
|
||||
offset, expectedPort, result.StartPort, result.EndPort)
|
||||
}
|
||||
|
||||
if result.Network != addr.Network || result.Host != addr.Host {
|
||||
t.Errorf("Offset %d: network/host should be preserved", offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_Expand_LargeRange(t *testing.T) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8000,
|
||||
EndPort: 8010,
|
||||
}
|
||||
|
||||
expanded := addr.Expand()
|
||||
expectedSize := 11 // 8000 to 8010 inclusive
|
||||
|
||||
if len(expanded) != expectedSize {
|
||||
t.Errorf("Expected %d addresses, got %d", expectedSize, len(expanded))
|
||||
}
|
||||
|
||||
// Verify each address
|
||||
for i, expandedAddr := range expanded {
|
||||
expectedPort := uint(8000 + i)
|
||||
if expandedAddr.StartPort != expectedPort || expandedAddr.EndPort != expectedPort {
|
||||
t.Errorf("Address %d: expected port %d, got %d-%d",
|
||||
i, expectedPort, expandedAddr.StartPort, expandedAddr.EndPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_IsLoopback_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "unix socket",
|
||||
addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"},
|
||||
expected: true, // Unix sockets are always considered loopback
|
||||
},
|
||||
{
|
||||
name: "fd network",
|
||||
addr: NetworkAddress{Network: "fd", Host: "3"},
|
||||
expected: true, // fd networks are always considered loopback
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "127.0.0.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "::1",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "::1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.2",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "127.0.0.2"},
|
||||
expected: true, // Part of 127.0.0.0/8 loopback range
|
||||
},
|
||||
{
|
||||
name: "192.168.1.1",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"},
|
||||
expected: false, // Private but not loopback
|
||||
},
|
||||
{
|
||||
name: "invalid ip",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "invalid-ip"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
addr: NetworkAddress{Network: "tcp", Host: ""},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.isLoopback()
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_IsWildcard_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty host",
|
||||
addr: NetworkAddress{Network: "tcp", Host: ""},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ipv4 any",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "0.0.0.0"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 any",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "::"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "specific ip",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid ip",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "invalid"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.isWildcardInterface()
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitNetworkAddress_IPv6_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectNetwork string
|
||||
expectHost string
|
||||
expectPort string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "ipv6 with port",
|
||||
input: "[::1]:8080",
|
||||
expectHost: "::1",
|
||||
expectPort: "8080",
|
||||
},
|
||||
{
|
||||
name: "ipv6 without port",
|
||||
input: "[::1]",
|
||||
expectHost: "::1",
|
||||
},
|
||||
{
|
||||
name: "ipv6 without brackets or port",
|
||||
input: "::1",
|
||||
expectHost: "::1",
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback",
|
||||
input: "[::1]:443",
|
||||
expectHost: "::1",
|
||||
expectPort: "443",
|
||||
},
|
||||
{
|
||||
name: "ipv6 any address",
|
||||
input: "[::]:80",
|
||||
expectHost: "::",
|
||||
expectPort: "80",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with network prefix",
|
||||
input: "tcp6/[::1]:8080",
|
||||
expectNetwork: "tcp6",
|
||||
expectHost: "::1",
|
||||
expectPort: "8080",
|
||||
},
|
||||
{
|
||||
name: "malformed ipv6",
|
||||
input: "[::1:8080", // Missing closing bracket
|
||||
expectHost: "::1:8080",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with zone",
|
||||
input: "[fe80::1%eth0]:8080",
|
||||
expectHost: "fe80::1%eth0",
|
||||
expectPort: "8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
network, host, port, err := SplitNetworkAddress(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if network != test.expectNetwork {
|
||||
t.Errorf("Network: expected '%s', got '%s'", test.expectNetwork, network)
|
||||
}
|
||||
if host != test.expectHost {
|
||||
t.Errorf("Host: expected '%s', got '%s'", test.expectHost, host)
|
||||
}
|
||||
if port != test.expectPort {
|
||||
t.Errorf("Port: expected '%s', got '%s'", test.expectPort, port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNetworkAddress_PortRange_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid range",
|
||||
input: "localhost:8080-8090",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "inverted range",
|
||||
input: "localhost:8090-8080",
|
||||
expectErr: true,
|
||||
errMsg: "end port must not be less than start port",
|
||||
},
|
||||
{
|
||||
name: "too large range",
|
||||
input: "localhost:0-65535",
|
||||
expectErr: true,
|
||||
errMsg: "port range exceeds 65535 ports",
|
||||
},
|
||||
{
|
||||
name: "invalid start port",
|
||||
input: "localhost:abc-8080",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid end port",
|
||||
input: "localhost:8080-xyz",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "port too large",
|
||||
input: "localhost:99999",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative port",
|
||||
input: "localhost:-80",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
_, err := ParseNetworkAddress(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if test.expectErr && test.errMsg != "" && err != nil {
|
||||
if !containsString(err.Error(), test.errMsg) {
|
||||
t.Errorf("Expected error containing '%s', got '%s'", test.errMsg, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_Listen_ContextCancellation(t *testing.T) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 0, // Let OS assign port
|
||||
EndPort: 0,
|
||||
}
|
||||
|
||||
// Create context that will be cancelled
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Start listening in a goroutine
|
||||
listenDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := addr.Listen(ctx, 0, net.ListenConfig{})
|
||||
listenDone <- err
|
||||
}()
|
||||
|
||||
// Cancel context immediately
|
||||
cancel()
|
||||
|
||||
// Should get context cancellation error quickly
|
||||
select {
|
||||
case err := <-listenDone:
|
||||
if err == nil {
|
||||
t.Error("Expected error due to context cancellation")
|
||||
}
|
||||
// Accept any error related to context cancellation
|
||||
// (could be context.Canceled or DNS lookup error due to cancellation)
|
||||
case <-time.After(time.Second):
|
||||
t.Error("Listen operation did not respect context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_ListenAll_PartialFailure(t *testing.T) {
|
||||
// Create an address range where some ports might fail to bind
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 0, // OS-assigned port
|
||||
EndPort: 2, // Try to bind 3 ports starting from OS-assigned
|
||||
}
|
||||
|
||||
// This test might be flaky depending on available ports,
|
||||
// but tests the error handling logic
|
||||
ctx := context.Background()
|
||||
|
||||
listeners, err := addr.ListenAll(ctx, net.ListenConfig{})
|
||||
|
||||
// Either all succeed or all fail (due to cleanup on partial failure)
|
||||
if err != nil {
|
||||
// If there's an error, no listeners should be returned
|
||||
if len(listeners) != 0 {
|
||||
t.Errorf("Expected no listeners on error, got %d", len(listeners))
|
||||
}
|
||||
} else {
|
||||
// If successful, should have listeners for all ports in range
|
||||
expectedCount := int(addr.PortRangeSize())
|
||||
if len(listeners) != expectedCount {
|
||||
t.Errorf("Expected %d listeners, got %d", expectedCount, len(listeners))
|
||||
}
|
||||
|
||||
// Clean up listeners
|
||||
for _, ln := range listeners {
|
||||
if closer, ok := ln.(interface{ Close() error }); ok {
|
||||
closer.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinNetworkAddress_SpecialCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
host string
|
||||
port string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty everything",
|
||||
network: "",
|
||||
host: "",
|
||||
port: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "network only",
|
||||
network: "tcp",
|
||||
host: "",
|
||||
port: "",
|
||||
expected: "tcp/",
|
||||
},
|
||||
{
|
||||
name: "host only",
|
||||
network: "",
|
||||
host: "localhost",
|
||||
port: "",
|
||||
expected: "localhost",
|
||||
},
|
||||
{
|
||||
name: "port only",
|
||||
network: "",
|
||||
host: "",
|
||||
port: "8080",
|
||||
expected: ":8080",
|
||||
},
|
||||
{
|
||||
name: "unix socket with port (port ignored)",
|
||||
network: "unix",
|
||||
host: "/tmp/socket",
|
||||
port: "8080",
|
||||
expected: "unix//tmp/socket",
|
||||
},
|
||||
{
|
||||
name: "fd network with port (port ignored)",
|
||||
network: "fd",
|
||||
host: "3",
|
||||
port: "8080",
|
||||
expected: "fd/3",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with port",
|
||||
network: "tcp",
|
||||
host: "::1",
|
||||
port: "8080",
|
||||
expected: "tcp/[::1]:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := JoinNetworkAddress(test.network, test.host, test.port)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsUnixNetwork_IsFdNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
network string
|
||||
isUnix bool
|
||||
isFd bool
|
||||
}{
|
||||
{"unix", true, false},
|
||||
{"unixgram", true, false},
|
||||
{"unixpacket", true, false},
|
||||
{"fd", false, true},
|
||||
{"fdgram", false, true},
|
||||
{"tcp", false, false},
|
||||
{"udp", false, false},
|
||||
{"", false, false},
|
||||
{"unix-like", true, false},
|
||||
{"fd-like", false, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.network, func(t *testing.T) {
|
||||
if IsUnixNetwork(test.network) != test.isUnix {
|
||||
t.Errorf("IsUnixNetwork('%s'): expected %v, got %v",
|
||||
test.network, test.isUnix, IsUnixNetwork(test.network))
|
||||
}
|
||||
if IsFdNetwork(test.network) != test.isFd {
|
||||
t.Errorf("IsFdNetwork('%s'): expected %v, got %v",
|
||||
test.network, test.isFd, IsFdNetwork(test.network))
|
||||
}
|
||||
|
||||
// Test NetworkAddress methods too
|
||||
addr := NetworkAddress{Network: test.network}
|
||||
if addr.IsUnixNetwork() != test.isUnix {
|
||||
t.Errorf("NetworkAddress.IsUnixNetwork(): expected %v, got %v",
|
||||
test.isUnix, addr.IsUnixNetwork())
|
||||
}
|
||||
if addr.IsFdNetwork() != test.isFd {
|
||||
t.Errorf("NetworkAddress.IsFdNetwork(): expected %v, got %v",
|
||||
test.isFd, addr.IsFdNetwork())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterNetwork_Validation(t *testing.T) {
|
||||
// Save original state
|
||||
originalNetworkTypes := make(map[string]ListenerFunc)
|
||||
for k, v := range networkTypes {
|
||||
originalNetworkTypes[k] = v
|
||||
}
|
||||
defer func() {
|
||||
// Restore original state
|
||||
networkTypes = originalNetworkTypes
|
||||
}()
|
||||
|
||||
mockListener := func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Test reserved network types that should panic
|
||||
reservedTypes := []string{
|
||||
"tcp", "tcp4", "tcp6",
|
||||
"udp", "udp4", "udp6",
|
||||
"unix", "unixpacket", "unixgram",
|
||||
"ip:1", "ip4:1", "ip6:1",
|
||||
"fd", "fdgram",
|
||||
}
|
||||
|
||||
for _, networkType := range reservedTypes {
|
||||
t.Run("reserved_"+networkType, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Expected panic for reserved network type: %s", networkType)
|
||||
}
|
||||
}()
|
||||
RegisterNetwork(networkType, mockListener)
|
||||
})
|
||||
}
|
||||
|
||||
// Test valid registration
|
||||
t.Run("valid_registration", func(t *testing.T) {
|
||||
customNetwork := "custom-network"
|
||||
RegisterNetwork(customNetwork, mockListener)
|
||||
|
||||
if _, exists := networkTypes[customNetwork]; !exists {
|
||||
t.Error("Custom network should be registered")
|
||||
}
|
||||
})
|
||||
|
||||
// Test duplicate registration should panic
|
||||
t.Run("duplicate_registration", func(t *testing.T) {
|
||||
customNetwork := "another-custom"
|
||||
RegisterNetwork(customNetwork, mockListener)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("Expected panic for duplicate registration")
|
||||
}
|
||||
}()
|
||||
RegisterNetwork(customNetwork, mockListener)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerUsage_EdgeCases(t *testing.T) {
|
||||
// Test ListenerUsage function with various inputs
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
addr string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "non-existent listener",
|
||||
network: "tcp",
|
||||
addr: "localhost:9999",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "empty network and address",
|
||||
network: "",
|
||||
addr: "",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
network: "unix",
|
||||
addr: "/tmp/non-existent.sock",
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
usage := ListenerUsage(test.network, test.addr)
|
||||
if usage != test.expected {
|
||||
t.Errorf("Expected usage %d, got %d", test.expected, usage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_Port_Formatting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single port",
|
||||
addr: NetworkAddress{StartPort: 80, EndPort: 80},
|
||||
expected: "80",
|
||||
},
|
||||
{
|
||||
name: "port range",
|
||||
addr: NetworkAddress{StartPort: 8080, EndPort: 8090},
|
||||
expected: "8080-8090",
|
||||
},
|
||||
{
|
||||
name: "zero ports",
|
||||
addr: NetworkAddress{StartPort: 0, EndPort: 0},
|
||||
expected: "0",
|
||||
},
|
||||
{
|
||||
name: "large ports",
|
||||
addr: NetworkAddress{StartPort: 65534, EndPort: 65535},
|
||||
expected: "65534-65535",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.port()
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_JoinHostPort_SpecialNetworks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
offset uint
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "unix socket ignores offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "unix",
|
||||
Host: "/tmp/socket",
|
||||
},
|
||||
offset: 100,
|
||||
expected: "/tmp/socket",
|
||||
},
|
||||
{
|
||||
name: "fd network ignores offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "fd",
|
||||
Host: "3",
|
||||
},
|
||||
offset: 50,
|
||||
expected: "3",
|
||||
},
|
||||
{
|
||||
name: "tcp with offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8000,
|
||||
},
|
||||
offset: 10,
|
||||
expected: "localhost:8010",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "::1",
|
||||
StartPort: 8000,
|
||||
},
|
||||
offset: 5,
|
||||
expected: "[::1]:8005",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.JoinHostPort(test.offset)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for string containment check
|
||||
func containsString(haystack, needle string) bool {
|
||||
return len(haystack) >= len(needle) &&
|
||||
(needle == "" || haystack == needle ||
|
||||
strings.Contains(haystack, needle))
|
||||
}
|
||||
|
||||
func TestListenerKey_Generation(t *testing.T) {
|
||||
tests := []struct {
|
||||
network string
|
||||
addr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
network: "tcp",
|
||||
addr: "localhost:8080",
|
||||
expected: "tcp/localhost:8080",
|
||||
},
|
||||
{
|
||||
network: "unix",
|
||||
addr: "/tmp/socket",
|
||||
expected: "unix//tmp/socket",
|
||||
},
|
||||
{
|
||||
network: "",
|
||||
addr: "localhost:8080",
|
||||
expected: "/localhost:8080",
|
||||
},
|
||||
{
|
||||
network: "tcp",
|
||||
addr: "",
|
||||
expected: "tcp/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s_%s", test.network, test.addr), func(t *testing.T) {
|
||||
result := listenerKey(test.network, test.addr)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_ConcurrentAccess(t *testing.T) {
|
||||
// Test that NetworkAddress methods are safe for concurrent read access
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8080,
|
||||
EndPort: 8090,
|
||||
}
|
||||
|
||||
const numGoroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Call various methods concurrently
|
||||
_ = addr.String()
|
||||
_ = addr.PortRangeSize()
|
||||
_ = addr.IsUnixNetwork()
|
||||
_ = addr.IsFdNetwork()
|
||||
_ = addr.isLoopback()
|
||||
_ = addr.isWildcardInterface()
|
||||
_ = addr.port()
|
||||
_ = addr.JoinHostPort(uint(id % 10))
|
||||
_ = addr.At(uint(id % 11))
|
||||
|
||||
// Expand creates new slice, should be safe
|
||||
expanded := addr.Expand()
|
||||
if len(expanded) == 0 {
|
||||
t.Errorf("Goroutine %d: Expected non-empty expansion", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestNetworkAddress_IPv6_Zone_Handling(t *testing.T) {
|
||||
// Test IPv6 addresses with zone identifiers
|
||||
input := "tcp/[fe80::1%eth0]:8080"
|
||||
|
||||
addr, err := ParseNetworkAddress(input)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse IPv6 with zone: %v", err)
|
||||
}
|
||||
|
||||
if addr.Network != "tcp" {
|
||||
t.Errorf("Expected network 'tcp', got '%s'", addr.Network)
|
||||
}
|
||||
if addr.Host != "fe80::1%eth0" {
|
||||
t.Errorf("Expected host 'fe80::1%%eth0', got '%s'", addr.Host)
|
||||
}
|
||||
if addr.StartPort != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", addr.StartPort)
|
||||
}
|
||||
|
||||
// Test string representation round-trip
|
||||
str := addr.String()
|
||||
parsed, err := ParseNetworkAddress(str)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse string representation: %v", err)
|
||||
}
|
||||
|
||||
if parsed.Host != addr.Host {
|
||||
t.Errorf("Round-trip failed: expected host '%s', got '%s'", addr.Host, parsed.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseNetworkAddress(b *testing.B) {
|
||||
inputs := []string{
|
||||
"localhost:8080",
|
||||
"tcp/localhost:8080-8090",
|
||||
"unix//tmp/socket",
|
||||
"[::1]:443",
|
||||
"udp/:53",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
input := inputs[i%len(inputs)]
|
||||
ParseNetworkAddress(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkAddress_String(b *testing.B) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8080,
|
||||
EndPort: 8090,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
addr.String()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkAddress_Expand(b *testing.B) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8000,
|
||||
EndPort: 8100, // 101 addresses
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
addr.Expand()
|
||||
}
|
||||
}
|
||||
+1113
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,624 @@
|
||||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockDestructor struct {
|
||||
value string
|
||||
destroyed int32
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockDestructor) Destruct() error {
|
||||
atomic.StoreInt32(&m.destroyed, 1)
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockDestructor) IsDestroyed() bool {
|
||||
return atomic.LoadInt32(&m.destroyed) == 1
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrNew_Basic(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
|
||||
// First load should construct new value
|
||||
val, loaded, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return &mockDestructor{value: "test-value"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for new value")
|
||||
}
|
||||
if val.(*mockDestructor).value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got '%s'", val.(*mockDestructor).value)
|
||||
}
|
||||
|
||||
// Second load should return existing value
|
||||
val2, loaded2, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
t.Error("Constructor should not be called for existing value")
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if !loaded2 {
|
||||
t.Error("Expected loaded to be true for existing value")
|
||||
}
|
||||
if val2.(*mockDestructor).value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got '%s'", val2.(*mockDestructor).value)
|
||||
}
|
||||
|
||||
// Check reference count
|
||||
refs, exists := pool.References(key)
|
||||
if !exists {
|
||||
t.Error("Key should exist in pool")
|
||||
}
|
||||
if refs != 2 {
|
||||
t.Errorf("Expected 2 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrNew_ConstructorError(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
expectedErr := errors.New("constructor failed")
|
||||
|
||||
val, loaded, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return nil, expectedErr
|
||||
})
|
||||
if err != expectedErr {
|
||||
t.Errorf("Expected constructor error, got: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for failed construction")
|
||||
}
|
||||
if val != nil {
|
||||
t.Error("Expected nil value for failed construction")
|
||||
}
|
||||
|
||||
// Key should not exist after constructor failure
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after constructor failure")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrStore_Basic(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
mockVal := &mockDestructor{value: "stored-value"}
|
||||
|
||||
// First load/store should store new value
|
||||
val, loaded := pool.LoadOrStore(key, mockVal)
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for new value")
|
||||
}
|
||||
if val != mockVal {
|
||||
t.Error("Expected stored value to be returned")
|
||||
}
|
||||
|
||||
// Second load/store should return existing value
|
||||
newMockVal := &mockDestructor{value: "new-value"}
|
||||
val2, loaded2 := pool.LoadOrStore(key, newMockVal)
|
||||
if !loaded2 {
|
||||
t.Error("Expected loaded to be true for existing value")
|
||||
}
|
||||
if val2 != mockVal {
|
||||
t.Error("Expected original stored value to be returned")
|
||||
}
|
||||
|
||||
// Check reference count
|
||||
refs, exists := pool.References(key)
|
||||
if !exists {
|
||||
t.Error("Key should exist in pool")
|
||||
}
|
||||
if refs != 2 {
|
||||
t.Errorf("Expected 2 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Delete_Basic(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
mockVal := &mockDestructor{value: "test-value"}
|
||||
|
||||
// Store value twice to get ref count of 2
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
|
||||
// First delete should decrement ref count
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if deleted {
|
||||
t.Error("Expected deleted to be false when refs > 0")
|
||||
}
|
||||
if mockVal.IsDestroyed() {
|
||||
t.Error("Value should not be destroyed yet")
|
||||
}
|
||||
|
||||
// Second delete should destroy value
|
||||
deleted, err = pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if !deleted {
|
||||
t.Error("Expected deleted to be true when refs = 0")
|
||||
}
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Value should be destroyed")
|
||||
}
|
||||
|
||||
// Key should not exist after deletion
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after deletion")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Delete_NonExistentKey(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
deleted, err := pool.Delete("non-existent")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for non-existent key, got: %v", err)
|
||||
}
|
||||
if deleted {
|
||||
t.Error("Expected deleted to be false for non-existent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Delete_PanicOnNegativeRefs(t *testing.T) {
|
||||
// This test demonstrates the panic condition by manipulating
|
||||
// the ref count directly to create an invalid state
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
mockVal := &mockDestructor{value: "test-value"}
|
||||
|
||||
// Store the value to get it in the pool
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
|
||||
// Get the pool value to manipulate its refs directly
|
||||
pool.Lock()
|
||||
upv, exists := pool.pool[key]
|
||||
if !exists {
|
||||
pool.Unlock()
|
||||
t.Fatal("Value should exist in pool")
|
||||
}
|
||||
|
||||
// Manually set refs to 1 to test the panic condition
|
||||
atomic.StoreInt32(&upv.refs, 1)
|
||||
pool.Unlock()
|
||||
|
||||
// Now delete twice - the second delete should cause refs to go negative
|
||||
// First delete
|
||||
deleted1, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("First delete failed: %v", err)
|
||||
}
|
||||
if !deleted1 {
|
||||
t.Error("First delete should have removed the value")
|
||||
}
|
||||
|
||||
// Second delete on the same key after it was removed should be safe
|
||||
deleted2, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Errorf("Second delete should not error: %v", err)
|
||||
}
|
||||
if deleted2 {
|
||||
t.Error("Second delete should return false for non-existent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Range(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
// Add multiple values
|
||||
values := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
}
|
||||
|
||||
for key, value := range values {
|
||||
pool.LoadOrStore(key, &mockDestructor{value: value})
|
||||
}
|
||||
|
||||
// Range through all values
|
||||
found := make(map[string]string)
|
||||
pool.Range(func(key, value any) bool {
|
||||
found[key.(string)] = value.(*mockDestructor).value
|
||||
return true
|
||||
})
|
||||
|
||||
if len(found) != len(values) {
|
||||
t.Errorf("Expected %d values, got %d", len(values), len(found))
|
||||
}
|
||||
|
||||
for key, expectedValue := range values {
|
||||
if actualValue, exists := found[key]; !exists || actualValue != expectedValue {
|
||||
t.Errorf("Key %s: expected '%s', got '%s'", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Range_EarlyReturn(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
// Add multiple values
|
||||
for i := 0; i < 5; i++ {
|
||||
pool.LoadOrStore(i, &mockDestructor{value: "value"})
|
||||
}
|
||||
|
||||
// Range but return false after first iteration
|
||||
count := 0
|
||||
pool.Range(func(key, value any) bool {
|
||||
count++
|
||||
return false // Stop after first iteration
|
||||
})
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 iteration, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Concurrent_LoadOrNew(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "concurrent-key"
|
||||
constructorCalls := int32(0)
|
||||
|
||||
const numGoroutines = 100
|
||||
var wg sync.WaitGroup
|
||||
results := make([]any, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
val, _, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
atomic.AddInt32(&constructorCalls, 1)
|
||||
// Add small delay to increase chance of race conditions
|
||||
time.Sleep(time.Microsecond)
|
||||
return &mockDestructor{value: "concurrent-value"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: Unexpected error: %v", index, err)
|
||||
return
|
||||
}
|
||||
results[index] = val
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Constructor should only be called once
|
||||
if calls := atomic.LoadInt32(&constructorCalls); calls != 1 {
|
||||
t.Errorf("Expected constructor to be called once, was called %d times", calls)
|
||||
}
|
||||
|
||||
// All goroutines should get the same value
|
||||
firstVal := results[0]
|
||||
for i, val := range results {
|
||||
if val != firstVal {
|
||||
t.Errorf("Goroutine %d got different value than first goroutine", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Reference count should equal number of goroutines
|
||||
refs, exists := pool.References(key)
|
||||
if !exists {
|
||||
t.Error("Key should exist in pool")
|
||||
}
|
||||
if refs != numGoroutines {
|
||||
t.Errorf("Expected %d references, got %d", numGoroutines, refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Concurrent_Delete(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "concurrent-delete-key"
|
||||
mockVal := &mockDestructor{value: "test-value"}
|
||||
|
||||
const numRefs = 50
|
||||
|
||||
// Add multiple references
|
||||
for i := 0; i < numRefs; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
deleteResults := make([]bool, numRefs)
|
||||
|
||||
// Delete concurrently
|
||||
for i := 0; i < numRefs; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: Unexpected error: %v", index, err)
|
||||
return
|
||||
}
|
||||
deleteResults[index] = deleted
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Exactly one delete should have returned true (when refs reached 0)
|
||||
deletedCount := 0
|
||||
for _, deleted := range deleteResults {
|
||||
if deleted {
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
if deletedCount != 1 {
|
||||
t.Errorf("Expected exactly 1 delete to return true, got %d", deletedCount)
|
||||
}
|
||||
|
||||
// Value should be destroyed
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Value should be destroyed after all references deleted")
|
||||
}
|
||||
|
||||
// Key should not exist
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after all references deleted")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_DestructorError(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "destructor-error-key"
|
||||
expectedErr := errors.New("destructor failed")
|
||||
mockVal := &mockDestructor{value: "test-value", err: expectedErr}
|
||||
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != expectedErr {
|
||||
t.Errorf("Expected destructor error, got: %v", err)
|
||||
}
|
||||
if !deleted {
|
||||
t.Error("Expected deleted to be true even with destructor error")
|
||||
}
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Destructor should have been called despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Mixed_Concurrent_Operations(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
keys := []string{"key1", "key2", "key3"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const opsPerKey = 10
|
||||
|
||||
// Test concurrent operations but with more controlled behavior
|
||||
for _, key := range keys {
|
||||
for i := 0; i < opsPerKey; i++ {
|
||||
wg.Add(2) // LoadOrStore and Delete
|
||||
|
||||
// LoadOrStore (safer than LoadOrNew for concurrency)
|
||||
go func(k string) {
|
||||
defer wg.Done()
|
||||
pool.LoadOrStore(k, &mockDestructor{value: k + "-value"})
|
||||
}(key)
|
||||
|
||||
// Delete (may fail if refs are 0, that's fine)
|
||||
go func(k string) {
|
||||
defer wg.Done()
|
||||
pool.Delete(k)
|
||||
}(key)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Test that the pool is in a consistent state
|
||||
for _, key := range keys {
|
||||
refs, exists := pool.References(key)
|
||||
if exists && refs < 0 {
|
||||
t.Errorf("Key %s has negative reference count: %d", key, refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Range_SkipsErrorValues(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
// Add value that will succeed
|
||||
goodKey := "good-key"
|
||||
pool.LoadOrStore(goodKey, &mockDestructor{value: "good-value"})
|
||||
|
||||
// Try to add value that will fail construction
|
||||
badKey := "bad-key"
|
||||
pool.LoadOrNew(badKey, func() (Destructor, error) {
|
||||
return nil, errors.New("construction failed")
|
||||
})
|
||||
|
||||
// Range should only iterate good values
|
||||
count := 0
|
||||
pool.Range(func(key, value any) bool {
|
||||
count++
|
||||
if key.(string) != goodKey {
|
||||
t.Errorf("Expected only good key, got: %s", key.(string))
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 value in range, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrStore_ErrorRecovery(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "error-recovery-key"
|
||||
|
||||
// First, create a value that fails construction
|
||||
_, _, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return nil, errors.New("construction failed")
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected constructor error")
|
||||
}
|
||||
|
||||
// Now try LoadOrStore with a good value - should recover
|
||||
goodVal := &mockDestructor{value: "recovery-value"}
|
||||
val, loaded := pool.LoadOrStore(key, goodVal)
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for error recovery")
|
||||
}
|
||||
if val != goodVal {
|
||||
t.Error("Expected recovery value to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_MemoryLeak_Prevention(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "memory-leak-test"
|
||||
|
||||
// Create many references
|
||||
const numRefs = 1000
|
||||
mockVal := &mockDestructor{value: "leak-test"}
|
||||
|
||||
for i := 0; i < numRefs; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
|
||||
// Delete all references
|
||||
for i := 0; i < numRefs; i++ {
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete %d: Unexpected error: %v", i, err)
|
||||
}
|
||||
if i == numRefs-1 && !deleted {
|
||||
t.Error("Last delete should return true")
|
||||
} else if i < numRefs-1 && deleted {
|
||||
t.Errorf("Delete %d should return false", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify destructor was called
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Value should be destroyed after all references deleted")
|
||||
}
|
||||
|
||||
// Verify no memory leak - key should be removed from map
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after complete deletion")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_RaceCondition_RefsCounter(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "race-test-key"
|
||||
mockVal := &mockDestructor{value: "race-value"}
|
||||
|
||||
const numOperations = 100
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Mix of increment and decrement operations
|
||||
for i := 0; i < numOperations; i++ {
|
||||
wg.Add(2)
|
||||
|
||||
// Increment (LoadOrStore)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}()
|
||||
|
||||
// Decrement (Delete) - may fail if refs are 0, that's ok
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.Delete(key)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Final reference count should be consistent
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
if refs < 0 {
|
||||
t.Errorf("Reference count should never be negative, got: %d", refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUsagePool_LoadOrNew(b *testing.B) {
|
||||
pool := NewUsagePool()
|
||||
key := "bench-key"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return &mockDestructor{value: "bench-value"}, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUsagePool_LoadOrStore(b *testing.B) {
|
||||
pool := NewUsagePool()
|
||||
key := "bench-key"
|
||||
mockVal := &mockDestructor{value: "bench-value"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUsagePool_Delete(b *testing.B) {
|
||||
pool := NewUsagePool()
|
||||
key := "bench-key"
|
||||
mockVal := &mockDestructor{value: "bench-value"}
|
||||
|
||||
// Pre-populate with many references
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.Delete(key)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user