From 2d6ff40649074c404c76807c3361bd3384b8894e Mon Sep 17 00:00:00 2001 From: Austin Date: Fri, 29 May 2015 19:21:50 -0700 Subject: [PATCH 1/4] add supported for ws in reverse proxy --- config/setup/rewrite_test.go | 3 ++- middleware/proxy/reverseproxy.go | 35 ++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/config/setup/rewrite_test.go b/config/setup/rewrite_test.go index 9ff294ef0..5747dee30 100644 --- a/config/setup/rewrite_test.go +++ b/config/setup/rewrite_test.go @@ -4,8 +4,9 @@ import ( "testing" "fmt" - "github.com/mholt/caddy/middleware/rewrite" "regexp" + + "github.com/mholt/caddy/middleware/rewrite" ) func TestRewrite(t *testing.T) { diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 027f2266c..1db1131e5 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -16,6 +16,8 @@ import ( "time" ) +const HTTPSwitchProtocols = 101 + // onExitFlushLoop is a callback set by tests to detect the state of the // flushLoop() goroutine. var onExitFlushLoop func() @@ -153,8 +155,37 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) - p.copyResponse(rw, res.Body) + if res.StatusCode == HTTPSwitchProtocols { + hj, ok := rw.(http.Hijacker) + if !ok { + return nil + } + + conn, _, err := hj.Hijack() + if err != nil { + return err + } + + backendConn, err := net.Dial("tcp", outreq.Host) + if err != nil { + conn.Close() + return err + } + + outreq.Write(backendConn) + + go func() { + io.Copy(backendConn, conn) // write tcp stream to backend. + backendConn.Close() + }() + + io.Copy(conn, backendConn) // read tcp stream from backend. + conn.Close() + } else { + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) + } + return nil } From 56ec7b98874c13e72807d344ad33c2a553d27b4e Mon Sep 17 00:00:00 2001 From: Austin Date: Sat, 30 May 2015 11:34:54 -0700 Subject: [PATCH 2/4] websocket directive, upgrade comparison --- middleware/proxy/reverseproxy.go | 4 ++-- middleware/proxy/upstream.go | 23 +++++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 1db1131e5..f3a0390b5 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -16,7 +16,7 @@ import ( "time" ) -const HTTPSwitchProtocols = 101 +const HTTPSwitchingProtocols = 101 // onExitFlushLoop is a callback set by tests to detect the state of the // flushLoop() goroutine. @@ -155,7 +155,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr copyHeader(rw.Header(), res.Header) - if res.StatusCode == HTTPSwitchProtocols { + if res.StatusCode == HTTPSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { hj, ok := rw.(http.Hijacker) if !ok { return nil diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index a657a088e..4c1b9fff7 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -12,7 +12,10 @@ import ( "github.com/mholt/caddy/config/parse" ) -var supportedPolicies map[string]func() Policy = make(map[string]func() Policy) +var ( + supportedPolicies map[string]func() Policy = make(map[string]func() Policy) + proxyHeaders http.Header +) type staticUpstream struct { from string @@ -40,7 +43,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { FailTimeout: 10 * time.Second, MaxFails: 1, } - var proxyHeaders http.Header + if !c.Args(&upstream.from) { return upstreams, c.ArgErr() } @@ -97,10 +100,10 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.Args(&header, &value) { return upstreams, c.ArgErr() } - if proxyHeaders == nil { - proxyHeaders = make(map[string][]string) - } - proxyHeaders.Add(header, value) + addProxyHeader(header, value) + case "websocket": + addProxyHeader("Connection", "{>Connection}") + addProxyHeader("Upgrade", "{>Upgrade}") } } @@ -150,6 +153,14 @@ func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy } +// AddProxyHeader adds a proxy header. +func addProxyHeader(header, value string) { + if proxyHeaders == nil { + proxyHeaders = make(map[string][]string) + } + proxyHeaders.Add(header, value) +} + func (u *staticUpstream) From() string { return u.from } From ccd3e55b328bddf85f3b33f870d64d2df676d3cf Mon Sep 17 00:00:00 2001 From: Austin Date: Mon, 1 Jun 2015 10:23:57 -0700 Subject: [PATCH 3/4] changes as noted in PR --- middleware/proxy/reverseproxy.go | 16 +++++++--------- middleware/proxy/upstream.go | 16 ++++------------ 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index f3a0390b5..a49d1080c 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -16,8 +16,6 @@ import ( "time" ) -const HTTPSwitchingProtocols = 101 - // onExitFlushLoop is a callback set by tests to detect the state of the // flushLoop() goroutine. var onExitFlushLoop func() @@ -149,13 +147,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr } defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } - - copyHeader(rw.Header(), res.Header) - - if res.StatusCode == HTTPSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { + if res.StatusCode == http.StatusSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { hj, ok := rw.(http.Hijacker) if !ok { return nil @@ -182,6 +174,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr io.Copy(conn, backendConn) // read tcp stream from backend. conn.Close() } else { + for _, h := range hopHeaders { + res.Header.Del(h) + } + + copyHeader(rw.Header(), res.Header) + rw.WriteHeader(res.StatusCode) p.copyResponse(rw, res.Body) } diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 4c1b9fff7..011a58b86 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -14,7 +14,7 @@ import ( var ( supportedPolicies map[string]func() Policy = make(map[string]func() Policy) - proxyHeaders http.Header + proxyHeaders http.Header = make(http.Header) ) type staticUpstream struct { @@ -100,10 +100,10 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.Args(&header, &value) { return upstreams, c.ArgErr() } - addProxyHeader(header, value) + proxyHeaders.Add(header, value) case "websocket": - addProxyHeader("Connection", "{>Connection}") - addProxyHeader("Upgrade", "{>Upgrade}") + proxyHeaders.Add("Connection", "{>Connection}") + proxyHeaders.Add("Upgrade", "{>Upgrade}") } } @@ -153,14 +153,6 @@ func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy } -// AddProxyHeader adds a proxy header. -func addProxyHeader(header, value string) { - if proxyHeaders == nil { - proxyHeaders = make(map[string][]string) - } - proxyHeaders.Add(header, value) -} - func (u *staticUpstream) From() string { return u.from } From 68cd4bdeab00e13e2087842c2ed9c6273725b3db Mon Sep 17 00:00:00 2001 From: Austin Date: Mon, 1 Jun 2015 19:29:32 -0700 Subject: [PATCH 4/4] check server response instead of client --- middleware/proxy/reverseproxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index a49d1080c..15350a993 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -147,7 +147,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr } defer res.Body.Close() - if res.StatusCode == http.StatusSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { + if res.StatusCode == http.StatusSwitchingProtocols && res.Header.Get("Upgrade") == "websocket" { hj, ok := rw.(http.Hijacker) if !ok { return nil