mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-25 15:52:45 -04:00 
			
		
		
		
	
						commit
						3faffdce2d
					
				
							
								
								
									
										403
									
								
								middleware/context_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										403
									
								
								middleware/context_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,403 @@ | ||||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| func TestInclude(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	inputFilename := "test_file" | ||||
| 	absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) | ||||
| 	defer func() { | ||||
| 		err := os.Remove(absInFilePath) | ||||
| 		if err != nil && !os.IsNotExist(err) { | ||||
| 			t.Fatalf("Failed to clean test file!") | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		fileContent          string | ||||
| 		expectedContent      string | ||||
| 		shouldErr            bool | ||||
| 		expectedErrorContent string | ||||
| 	}{ | ||||
| 		// Test 0 - all good | ||||
| 		{ | ||||
| 			fileContent:          `str1 {{ .Root }} str2`, | ||||
| 			expectedContent:      fmt.Sprintf("str1 %s str2", context.Root), | ||||
| 			shouldErr:            false, | ||||
| 			expectedErrorContent: "", | ||||
| 		}, | ||||
| 		// Test 1 - failure on template.Parse | ||||
| 		{ | ||||
| 			fileContent:          `str1 {{ .Root } str2`, | ||||
| 			expectedContent:      "", | ||||
| 			shouldErr:            true, | ||||
| 			expectedErrorContent: `unexpected "}" in operand`, | ||||
| 		}, | ||||
| 		// Test 3 - failure on template.Execute | ||||
| 		{ | ||||
| 			fileContent:          `str1 {{ .InvalidField }} str2`, | ||||
| 			expectedContent:      "", | ||||
| 			shouldErr:            true, | ||||
| 			expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, test := range tests { | ||||
| 		testPrefix := getTestPrefix(i) | ||||
| 
 | ||||
| 		// WriteFile truncates the contentt | ||||
| 		err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		content, err := context.Include(inputFilename) | ||||
| 		if err != nil { | ||||
| 			if !test.shouldErr { | ||||
| 				t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) | ||||
| 			} | ||||
| 			if !strings.Contains(err.Error(), test.expectedErrorContent) { | ||||
| 				t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		if err == nil && test.shouldErr { | ||||
| 			t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) | ||||
| 		} | ||||
| 
 | ||||
| 		if content != test.expectedContent { | ||||
| 			t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestIncludeNotExisting(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	_, err := context.Include("not_existing") | ||||
| 	if err == nil { | ||||
| 		t.Errorf("Expected error but found nil!") | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCookie(t *testing.T) { | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		cookie        *http.Cookie | ||||
| 		cookieName    string | ||||
| 		expectedValue string | ||||
| 	}{ | ||||
| 		// Test 0 - happy path | ||||
| 		{ | ||||
| 			cookie:        &http.Cookie{Name: "cookieName", Value: "cookieValue"}, | ||||
| 			cookieName:    "cookieName", | ||||
| 			expectedValue: "cookieValue", | ||||
| 		}, | ||||
| 		// Test 1 - try to get a non-existing cookie | ||||
| 		{ | ||||
| 			cookie:        &http.Cookie{Name: "cookieName", Value: "cookieValue"}, | ||||
| 			cookieName:    "notExisting", | ||||
| 			expectedValue: "", | ||||
| 		}, | ||||
| 		// Test 2 - partial name match | ||||
| 		{ | ||||
| 			cookie:        &http.Cookie{Name: "cookie", Value: "cookieValue"}, | ||||
| 			cookieName:    "cook", | ||||
| 			expectedValue: "", | ||||
| 		}, | ||||
| 		// Test 3 - cookie with optional fields | ||||
| 		{ | ||||
| 			cookie:        &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, | ||||
| 			cookieName:    "cookie", | ||||
| 			expectedValue: "cookieValue", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, test := range tests { | ||||
| 		testPrefix := getTestPrefix(i) | ||||
| 
 | ||||
| 		// reinitialize the context for each test | ||||
| 		context := getContextOrFail(t) | ||||
| 
 | ||||
| 		context.Req.AddCookie(test.cookie) | ||||
| 
 | ||||
| 		actualCookieVal := context.Cookie(test.cookieName) | ||||
| 
 | ||||
| 		if actualCookieVal != test.expectedValue { | ||||
| 			t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestCookieMultipleCookies(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	cookieNameBase, cookieValueBase := "cookieName", "cookieValue" | ||||
| 
 | ||||
| 	// make sure that there's no state and multiple requests for different cookies return the correct result | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) | ||||
| 	} | ||||
| 
 | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) | ||||
| 		actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) | ||||
| 		if actualCookieVal != expectedCookieVal { | ||||
| 			t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestHeader(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	headerKey, headerVal := "Header1", "HeaderVal1" | ||||
| 	context.Req.Header.Add(headerKey, headerVal) | ||||
| 
 | ||||
| 	actualHeaderVal := context.Header(headerKey) | ||||
| 	if actualHeaderVal != headerVal { | ||||
| 		t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) | ||||
| 	} | ||||
| 
 | ||||
| 	missingHeaderVal := context.Header("not-existing") | ||||
| 	if missingHeaderVal != "" { | ||||
| 		t.Errorf("Expected empty header value, found %s", missingHeaderVal) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestIP(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		inputRemoteAddr string | ||||
| 		expectedIP      string | ||||
| 	}{ | ||||
| 		// Test 0 - ipv4 with port | ||||
| 		{"1.1.1.1:1111", "1.1.1.1"}, | ||||
| 		// Test 1 - ipv4 without port | ||||
| 		{"1.1.1.1", "1.1.1.1"}, | ||||
| 		// Test 2 - ipv6 with port | ||||
| 		{"[::1]:11", "::1"}, | ||||
| 		// Test 3 - ipv6 without port and brackets | ||||
| 		{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, | ||||
| 		// Test 4 - ipv6 with zone and port | ||||
| 		{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, test := range tests { | ||||
| 		testPrefix := getTestPrefix(i) | ||||
| 
 | ||||
| 		context.Req.RemoteAddr = test.inputRemoteAddr | ||||
| 		actualIP := context.IP() | ||||
| 
 | ||||
| 		if actualIP != test.expectedIP { | ||||
| 			t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestURL(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	inputURL := "http://localhost" | ||||
| 	context.Req.RequestURI = inputURL | ||||
| 
 | ||||
| 	if inputURL != context.URI() { | ||||
| 		t.Errorf("Expected url %s, found %s", inputURL, context.URI()) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestHost(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input        string | ||||
| 		expectedHost string | ||||
| 		shouldErr    bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			input:        "localhost:123", | ||||
| 			expectedHost: "localhost", | ||||
| 			shouldErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			input:        "localhost", | ||||
| 			expectedHost: "", | ||||
| 			shouldErr:    true, // missing port in address | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestPort(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		input        string | ||||
| 		expectedPort string | ||||
| 		shouldErr    bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			input:        "localhost:123", | ||||
| 			expectedPort: "123", | ||||
| 			shouldErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			input:        "localhost", | ||||
| 			expectedPort: "", | ||||
| 			shouldErr:    true, // missing port in address | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	context.Req.Host = input | ||||
| 	var actualResult, testedObject string | ||||
| 	var err error | ||||
| 
 | ||||
| 	if isTestingHost { | ||||
| 		actualResult, err = context.Host() | ||||
| 		testedObject = "host" | ||||
| 	} else { | ||||
| 		actualResult, err = context.Port() | ||||
| 		testedObject = "port" | ||||
| 	} | ||||
| 
 | ||||
| 	if shouldErr && err == nil { | ||||
| 		t.Errorf("Expected error, found nil!") | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if !shouldErr && err != nil { | ||||
| 		t.Errorf("Expected no error, found %s", err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if actualResult != expectedResult { | ||||
| 		t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestMethod(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	method := "POST" | ||||
| 	context.Req.Method = method | ||||
| 
 | ||||
| 	if method != context.Method() { | ||||
| 		t.Errorf("Expected method %s, found %s", method, context.Method()) | ||||
| 	} | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func TestPathMatches(t *testing.T) { | ||||
| 	context := getContextOrFail(t) | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		urlStr      string | ||||
| 		pattern     string | ||||
| 		shouldMatch bool | ||||
| 	}{ | ||||
| 		// Test 0 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost/", | ||||
| 			pattern:     "", | ||||
| 			shouldMatch: true, | ||||
| 		}, | ||||
| 		// Test 1 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost", | ||||
| 			pattern:     "", | ||||
| 			shouldMatch: true, | ||||
| 		}, | ||||
| 		// Test 1 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost/", | ||||
| 			pattern:     "/", | ||||
| 			shouldMatch: true, | ||||
| 		}, | ||||
| 		// Test 3 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost/?param=val", | ||||
| 			pattern:     "/", | ||||
| 			shouldMatch: true, | ||||
| 		}, | ||||
| 		// Test 4 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost/dir1/dir2", | ||||
| 			pattern:     "/dir2", | ||||
| 			shouldMatch: false, | ||||
| 		}, | ||||
| 		// Test 5 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost/dir1/dir2", | ||||
| 			pattern:     "/dir1", | ||||
| 			shouldMatch: true, | ||||
| 		}, | ||||
| 		// Test 6 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost:444/dir1/dir2", | ||||
| 			pattern:     "/dir1", | ||||
| 			shouldMatch: true, | ||||
| 		}, | ||||
| 		// Test 7 | ||||
| 		{ | ||||
| 			urlStr:      "http://localhost/dir1/dir2", | ||||
| 			pattern:     "*/dir2", | ||||
| 			shouldMatch: false, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, test := range tests { | ||||
| 		testPrefix := getTestPrefix(i) | ||||
| 		var err error | ||||
| 		context.Req.URL, err = url.Parse(test.urlStr) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) | ||||
| 		} | ||||
| 
 | ||||
| 		matches := context.PathMatches(test.pattern) | ||||
| 		if matches != test.shouldMatch { | ||||
| 			t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func initTestContext() (Context, error) { | ||||
| 	body := bytes.NewBufferString("request body") | ||||
| 	request, err := http.NewRequest("GET", "https://localhost", body) | ||||
| 	if err != nil { | ||||
| 		return Context{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	return Context{Root: http.Dir(os.TempDir()), Req: request}, nil | ||||
| } | ||||
| 
 | ||||
| func getContextOrFail(t *testing.T) Context { | ||||
| 	context, err := initTestContext() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to prepare test context") | ||||
| 	} | ||||
| 	return context | ||||
| } | ||||
| 
 | ||||
| func getTestPrefix(testN int) string { | ||||
| 	return fmt.Sprintf("Test [%d]: ", testN) | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user