Adds failing test for using upstream Host header.
This commit is contained in:
		
							parent
							
								
									ade9502dd2
								
							
						
					
					
						commit
						20a152261c
					
				|  | @ -46,6 +46,10 @@ type OauthProxy struct { | |||
| 	compiledRegex       []*regexp.Regexp | ||||
| } | ||||
| 
 | ||||
| func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { | ||||
|     return httputil.NewSingleHostReverseProxy(target) | ||||
| } | ||||
| 
 | ||||
| func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | ||||
| 	login, _ := url.Parse("https://accounts.google.com/o/oauth2/auth") | ||||
| 	redeem, _ := url.Parse("https://accounts.google.com/o/oauth2/token") | ||||
|  | @ -54,7 +58,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | |||
| 		path := u.Path | ||||
| 		u.Path = "" | ||||
| 		log.Printf("mapping path %q => upstream %q", path, u) | ||||
| 		serveMux.Handle(path, httputil.NewSingleHostReverseProxy(u)) | ||||
| 		serveMux.Handle(path, NewReverseProxy(u)) | ||||
| 	} | ||||
| 	for _, u := range opts.CompiledRegex { | ||||
| 		log.Printf("compiled skip-auth-regex => %q", u) | ||||
|  |  | |||
|  | @ -0,0 +1,36 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func TestNewReverseProxy(t *testing.T) { | ||||
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(200) | ||||
|         hostname, _, _ := net.SplitHostPort(r.Host) | ||||
| 		w.Write([]byte(hostname)) | ||||
| 	})) | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	backendURL, _ := url.Parse(backend.URL) | ||||
| 	backendHostname := "upstream.127.0.0.1.xip.io" | ||||
| 	_, backendPort, _ := net.SplitHostPort(backendURL.Host) | ||||
| 	backendHost := net.JoinHostPort(backendHostname, backendPort) | ||||
| 	proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") | ||||
| 
 | ||||
| 	proxyHandler := NewReverseProxy(proxyURL) | ||||
| 	frontend := httptest.NewServer(proxyHandler) | ||||
| 	defer frontend.Close() | ||||
| 
 | ||||
| 	getReq, _ := http.NewRequest("GET", frontend.URL, nil) | ||||
| 	res, _ := http.DefaultClient.Do(getReq) | ||||
| 	bodyBytes, _ := ioutil.ReadAll(res.Body) | ||||
| 	if g, e := string(bodyBytes), backendHostname; g != e { | ||||
| 		t.Errorf("got body %q; expected %q", g, e) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
		Reference in New Issue