mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-03 19:17:29 -05:00 
			
		
		
		
	Merge pull request #699 from eiszfuchs/socket-url
proxy: fix req.URL.Path for unix sockets
This commit is contained in:
		
						commit
						5989eb0635
					
				@ -238,6 +238,116 @@ func TestUnixSocketProxy(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) {
 | 
			
		||||
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		fmt.Fprintf(w, messageFormat, r.URL.String())
 | 
			
		||||
	}))
 | 
			
		||||
 | 
			
		||||
	return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) {
 | 
			
		||||
	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		fmt.Fprintf(w, messageFormat, r.URL.String())
 | 
			
		||||
	}))
 | 
			
		||||
 | 
			
		||||
	socketPath, err := filepath.Abs("./test_socket")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ln, err := net.Listen("unix", socketPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("Unable to listen: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	ts.Listener = ln
 | 
			
		||||
 | 
			
		||||
	ts.Start()
 | 
			
		||||
 | 
			
		||||
	tsURL := strings.Replace(ts.URL, "http://", "unix:", 1)
 | 
			
		||||
 | 
			
		||||
	return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) {
 | 
			
		||||
	echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		p.ServeHTTP(w, r)
 | 
			
		||||
	}))
 | 
			
		||||
 | 
			
		||||
	// *httptest.Server is passed so it can be `defer`red properly
 | 
			
		||||
	defer ts.Close()
 | 
			
		||||
	defer echoProxy.Close()
 | 
			
		||||
 | 
			
		||||
	res, err := http.Get(echoProxy.URL + path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("Unable to GET: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	greeting, err := ioutil.ReadAll(res.Body)
 | 
			
		||||
	res.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("Unable to read body: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("%s", greeting), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUnixSocketProxyPaths(t *testing.T) {
 | 
			
		||||
	greeting := "Hello route %s"
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		url      string
 | 
			
		||||
		prefix   string
 | 
			
		||||
		expected string
 | 
			
		||||
	}{
 | 
			
		||||
		{"", "", fmt.Sprintf(greeting, "/")},
 | 
			
		||||
		{"/hello", "", fmt.Sprintf(greeting, "/hello")},
 | 
			
		||||
		{"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")},
 | 
			
		||||
		{"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")},
 | 
			
		||||
		{"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")},
 | 
			
		||||
		{"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")},
 | 
			
		||||
		{"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")},
 | 
			
		||||
		{"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")},
 | 
			
		||||
		{"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		p, ts := GetHTTPProxy(greeting, test.prefix)
 | 
			
		||||
 | 
			
		||||
		actualMsg, err := GetTestServerMessage(p, ts, test.url)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Getting server message failed - %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if actualMsg != test.expected {
 | 
			
		||||
			t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if runtime.GOOS == "windows" {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		p, ts, err := GetSocketProxy(greeting, test.prefix)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Getting socket proxy failed - %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		actualMsg, err := GetTestServerMessage(p, ts, test.url)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Getting server message failed - %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if actualMsg != test.expected {
 | 
			
		||||
			t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
 | 
			
		||||
	uri, _ := url.Parse(name)
 | 
			
		||||
	u := &fakeUpstream{
 | 
			
		||||
@ -276,12 +386,19 @@ func (u *fakeUpstream) AllowedPath(requestPath string) bool {
 | 
			
		||||
// proxy.
 | 
			
		||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
 | 
			
		||||
	return &Proxy{
 | 
			
		||||
		Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}},
 | 
			
		||||
		Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
 | 
			
		||||
	return &Proxy{
 | 
			
		||||
		Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type fakeWsUpstream struct {
 | 
			
		||||
	name string
 | 
			
		||||
	name    string
 | 
			
		||||
	without string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (u *fakeWsUpstream) From() string {
 | 
			
		||||
@ -292,7 +409,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost {
 | 
			
		||||
	uri, _ := url.Parse(u.name)
 | 
			
		||||
	return &UpstreamHost{
 | 
			
		||||
		Name:         u.name,
 | 
			
		||||
		ReverseProxy: NewSingleHostReverseProxy(uri, ""),
 | 
			
		||||
		ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
 | 
			
		||||
		ExtraHeaders: http.Header{
 | 
			
		||||
			"Connection": {"{>Connection}"},
 | 
			
		||||
			"Upgrade":    {"{>Upgrade}"}},
 | 
			
		||||
 | 
			
		||||
@ -95,6 +95,18 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
 | 
			
		||||
		} else {
 | 
			
		||||
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
 | 
			
		||||
		}
 | 
			
		||||
		// Trims the path of the socket from the URL path.
 | 
			
		||||
		// This is done because req.URL passed to your proxied service
 | 
			
		||||
		// will have the full path of the socket file prefixed to it.
 | 
			
		||||
		// Calling /test on a server that proxies requests to
 | 
			
		||||
		// unix:/var/run/www.socket will thus set the requested path
 | 
			
		||||
		// to /var/run/www.socket/test, rendering paths useless.
 | 
			
		||||
		if target.Scheme == "unix" {
 | 
			
		||||
			// See comment on socketDial for the trim
 | 
			
		||||
			socketPrefix := target.String()[len("unix://"):]
 | 
			
		||||
			req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
 | 
			
		||||
		}
 | 
			
		||||
		// We are then safe to remove the `without` prefix.
 | 
			
		||||
		if without != "" {
 | 
			
		||||
			req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user