Merge pull request #593 from oauth2-proxy/proxy-refactor
Integrate upstream package with OAuth2 Proxy
This commit is contained in:
		
						commit
						1aac37d2b1
					
				|  | @ -11,6 +11,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v6.0.0 | ## Changes since v6.0.0 | ||||||
| 
 | 
 | ||||||
|  | - [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed) | ||||||
| - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) | - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) | ||||||
| - [#624](https://github.com/oauth2-proxy/oauth2-proxy/pull/624) Allow stripping authentication headers from whitelisted requests with `--skip-auth-strip-headers` (@NickMeves) | - [#624](https://github.com/oauth2-proxy/oauth2-proxy/pull/624) Allow stripping authentication headers from whitelisted requests with `--skip-auth-strip-headers` (@NickMeves) | ||||||
| - [#673](https://github.com/oauth2-proxy/oauth2-proxy/pull/673) Add --session-cookie-minimal option to create session cookies with no tokens (@NickMeves) | - [#673](https://github.com/oauth2-proxy/oauth2-proxy/pull/673) Add --session-cookie-minimal option to create session cookies with no tokens (@NickMeves) | ||||||
|  |  | ||||||
							
								
								
									
										10
									
								
								main.go
								
								
								
								
							
							
						
						
									
										10
									
								
								main.go
								
								
								
								
							|  | @ -32,13 +32,19 @@ func main() { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	opts := options.NewOptions() | 	legacyOpts := options.NewLegacyOptions() | ||||||
| 	err := options.Load(*config, flagSet, opts) | 	err := options.Load(*config, flagSet, legacyOpts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("ERROR: Failed to load config: %v", err) | 		logger.Printf("ERROR: Failed to load config: %v", err) | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	opts, err := legacyOpts.ToOptions() | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Printf("ERROR: Failed to convert config: %v", err) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	err = validation.Validate(opts) | 	err = validation.Validate(opts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("%s", err) | 		logger.Printf("%s", err) | ||||||
|  |  | ||||||
							
								
								
									
										185
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										185
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -2,7 +2,6 @@ package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/tls" |  | ||||||
| 	b64 "encoding/base64" | 	b64 "encoding/base64" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | @ -10,15 +9,12 @@ import ( | ||||||
| 	"html/template" | 	"html/template" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httputil" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"strconv" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc" | 	"github.com/coreos/go-oidc" | ||||||
| 	"github.com/mbland/hmacauth" |  | ||||||
| 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | @ -28,37 +24,17 @@ import ( | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/ip" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/ip" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/providers" | 	"github.com/oauth2-proxy/oauth2-proxy/providers" | ||||||
| 	"github.com/yhat/wsutil" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	// SignatureHeader is the name of the request header containing the GAP Signature
 |  | ||||||
| 	// Part of hmacauth
 |  | ||||||
| 	SignatureHeader = "GAP-Signature" |  | ||||||
| 
 |  | ||||||
| 	httpScheme  = "http" | 	httpScheme  = "http" | ||||||
| 	httpsScheme = "https" | 	httpsScheme = "https" | ||||||
| 
 | 
 | ||||||
| 	applicationJSON = "application/json" | 	applicationJSON = "application/json" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // SignatureHeaders contains the headers to be signed by the hmac algorithm
 |  | ||||||
| // Part of hmacauth
 |  | ||||||
| var SignatureHeaders = []string{ |  | ||||||
| 	"Content-Length", |  | ||||||
| 	"Content-Md5", |  | ||||||
| 	"Content-Type", |  | ||||||
| 	"Date", |  | ||||||
| 	"Authorization", |  | ||||||
| 	"X-Forwarded-User", |  | ||||||
| 	"X-Forwarded-Email", |  | ||||||
| 	"X-Forwarded-Preferred-User", |  | ||||||
| 	"X-Forwarded-Access-Token", |  | ||||||
| 	"Cookie", |  | ||||||
| 	"Gap-Auth", |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var ( | var ( | ||||||
| 	// ErrNeedsLogin means the user should be redirected to the login page
 | 	// ErrNeedsLogin means the user should be redirected to the login page
 | ||||||
| 	ErrNeedsLogin = errors.New("redirect to login page") | 	ErrNeedsLogin = errors.New("redirect to login page") | ||||||
|  | @ -124,116 +100,6 @@ type OAuthProxy struct { | ||||||
| 	Footer                  string | 	Footer                  string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UpstreamProxy represents an upstream server to proxy to
 |  | ||||||
| type UpstreamProxy struct { |  | ||||||
| 	upstream  string |  | ||||||
| 	handler   http.Handler |  | ||||||
| 	wsHandler http.Handler |  | ||||||
| 	auth      hmacauth.HmacAuth |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ServeHTTP proxies requests to the upstream provider while signing the
 |  | ||||||
| // request headers
 |  | ||||||
| func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 	w.Header().Set("GAP-Upstream-Address", u.upstream) |  | ||||||
| 	if u.auth != nil { |  | ||||||
| 		r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) |  | ||||||
| 		u.auth.SignRequest(r) |  | ||||||
| 	} |  | ||||||
| 	if u.wsHandler != nil && strings.EqualFold(r.Header.Get("Connection"), "upgrade") && r.Header.Get("Upgrade") == "websocket" { |  | ||||||
| 		u.wsHandler.ServeHTTP(w, r) |  | ||||||
| 	} else { |  | ||||||
| 		u.handler.ServeHTTP(w, r) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NewReverseProxy creates a new reverse proxy for proxying requests to upstream
 |  | ||||||
| // servers
 |  | ||||||
| func NewReverseProxy(target *url.URL, opts *options.Options) (proxy *httputil.ReverseProxy) { |  | ||||||
| 	proxy = httputil.NewSingleHostReverseProxy(target) |  | ||||||
| 	proxy.FlushInterval = opts.FlushInterval |  | ||||||
| 	if opts.SSLUpstreamInsecureSkipVerify { |  | ||||||
| 		proxy.Transport = &http.Transport{ |  | ||||||
| 			TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	setProxyErrorHandler(proxy, opts) |  | ||||||
| 	return proxy |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func setProxyErrorHandler(proxy *httputil.ReverseProxy, opts *options.Options) { |  | ||||||
| 	templates := loadTemplates(opts.CustomTemplatesDir) |  | ||||||
| 	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, proxyErr error) { |  | ||||||
| 		logger.Printf("Error proxying to upstream server: %v", proxyErr) |  | ||||||
| 		w.WriteHeader(http.StatusBadGateway) |  | ||||||
| 		data := struct { |  | ||||||
| 			Title       string |  | ||||||
| 			Message     string |  | ||||||
| 			ProxyPrefix string |  | ||||||
| 		}{ |  | ||||||
| 			Title:       "Bad Gateway", |  | ||||||
| 			Message:     "Error proxying to upstream server", |  | ||||||
| 			ProxyPrefix: opts.ProxyPrefix, |  | ||||||
| 		} |  | ||||||
| 		templates.ExecuteTemplate(w, "error.html", data) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { |  | ||||||
| 	director := proxy.Director |  | ||||||
| 	proxy.Director = func(req *http.Request) { |  | ||||||
| 		director(req) |  | ||||||
| 		// use RequestURI so that we aren't unescaping encoded slashes in the request path
 |  | ||||||
| 		req.Host = target.Host |  | ||||||
| 		req.URL.Opaque = req.RequestURI |  | ||||||
| 		req.URL.RawQuery = "" |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func setProxyDirector(proxy *httputil.ReverseProxy) { |  | ||||||
| 	director := proxy.Director |  | ||||||
| 	proxy.Director = func(req *http.Request) { |  | ||||||
| 		director(req) |  | ||||||
| 		// use RequestURI so that we aren't unescaping encoded slashes in the request path
 |  | ||||||
| 		req.URL.Opaque = req.RequestURI |  | ||||||
| 		req.URL.RawQuery = "" |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NewFileServer creates a http.Handler to serve files from the filesystem
 |  | ||||||
| func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { |  | ||||||
| 	return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NewWebSocketOrRestReverseProxy creates a reverse proxy for REST or websocket based on url
 |  | ||||||
| func NewWebSocketOrRestReverseProxy(u *url.URL, opts *options.Options, auth hmacauth.HmacAuth) http.Handler { |  | ||||||
| 	u.Path = "" |  | ||||||
| 	proxy := NewReverseProxy(u, opts) |  | ||||||
| 	if !opts.PassHostHeader { |  | ||||||
| 		setProxyUpstreamHostHeader(proxy, u) |  | ||||||
| 	} else { |  | ||||||
| 		setProxyDirector(proxy) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// this should give us a wss:// scheme if the url is https:// based.
 |  | ||||||
| 	var wsProxy *wsutil.ReverseProxy |  | ||||||
| 	if opts.ProxyWebSockets { |  | ||||||
| 		wsScheme := "ws" + strings.TrimPrefix(u.Scheme, "http") |  | ||||||
| 		wsURL := &url.URL{Scheme: wsScheme, Host: u.Host} |  | ||||||
| 		wsProxy = wsutil.NewSingleHostReverseProxy(wsURL) |  | ||||||
| 		if opts.SSLUpstreamInsecureSkipVerify { |  | ||||||
| 			wsProxy.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return &UpstreamProxy{ |  | ||||||
| 		upstream:  u.Host, |  | ||||||
| 		handler:   proxy, |  | ||||||
| 		wsHandler: wsProxy, |  | ||||||
| 		auth:      auth, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | ||||||
| func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { | func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { | ||||||
| 	sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) | 	sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) | ||||||
|  | @ -241,48 +107,13 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		return nil, fmt.Errorf("error initialising session store: %v", err) | 		return nil, fmt.Errorf("error initialising session store: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	serveMux := http.NewServeMux() | 	templates := loadTemplates(opts.CustomTemplatesDir) | ||||||
| 	var auth hmacauth.HmacAuth | 	proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) | ||||||
| 	if sigData := opts.GetSignatureData(); sigData != nil { | 	upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) | ||||||
| 		auth = hmacauth.NewHmacAuth(sigData.Hash, []byte(sigData.Key), | 	if err != nil { | ||||||
| 			SignatureHeader, SignatureHeaders) | 		return nil, fmt.Errorf("error initialising upstream proxy: %v", err) | ||||||
| 	} | 	} | ||||||
| 	for _, u := range opts.GetProxyURLs() { |  | ||||||
| 		path := u.Path |  | ||||||
| 		host := u.Host |  | ||||||
| 		switch u.Scheme { |  | ||||||
| 		case httpScheme, httpsScheme: |  | ||||||
| 			logger.Printf("mapping path %q => upstream %q", path, u) |  | ||||||
| 			proxy := NewWebSocketOrRestReverseProxy(u, opts, auth) |  | ||||||
| 			serveMux.Handle(path, proxy) |  | ||||||
| 		case "static": |  | ||||||
| 			responseCode, err := strconv.Atoi(host) |  | ||||||
| 			if err != nil { |  | ||||||
| 				logger.Printf("unable to convert %q to int, use default \"200\"", host) |  | ||||||
| 				responseCode = 200 |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			serveMux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) { |  | ||||||
| 				rw.WriteHeader(responseCode) |  | ||||||
| 				fmt.Fprintf(rw, "Authenticated") |  | ||||||
| 			}) |  | ||||||
| 		case "file": |  | ||||||
| 			if u.Fragment != "" { |  | ||||||
| 				path = u.Fragment |  | ||||||
| 			} |  | ||||||
| 			logger.Printf("mapping path %q => file system %q", path, u.Path) |  | ||||||
| 			proxy := NewFileServer(path, u.Path) |  | ||||||
| 			uProxy := UpstreamProxy{ |  | ||||||
| 				upstream:  path, |  | ||||||
| 				handler:   proxy, |  | ||||||
| 				wsHandler: nil, |  | ||||||
| 				auth:      nil, |  | ||||||
| 			} |  | ||||||
| 			serveMux.Handle(path, &uProxy) |  | ||||||
| 		default: |  | ||||||
| 			panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	for _, u := range opts.GetCompiledRegex() { | 	for _, u := range opts.GetCompiledRegex() { | ||||||
| 		logger.Printf("compiled skip-auth-regex => %q", u) | 		logger.Printf("compiled skip-auth-regex => %q", u) | ||||||
| 	} | 	} | ||||||
|  | @ -350,7 +181,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		provider:                opts.GetProvider(), | 		provider:                opts.GetProvider(), | ||||||
| 		providerNameOverride:    opts.ProviderName, | 		providerNameOverride:    opts.ProviderName, | ||||||
| 		sessionStore:            sessionStore, | 		sessionStore:            sessionStore, | ||||||
| 		serveMux:                serveMux, | 		serveMux:                upstreamProxy, | ||||||
| 		redirectURL:             redirectURL, | 		redirectURL:             redirectURL, | ||||||
| 		whitelistDomains:        opts.WhitelistDomains, | 		whitelistDomains:        opts.WhitelistDomains, | ||||||
| 		skipAuthRegex:           opts.SkipAuthRegex, | 		skipAuthRegex:           opts.SkipAuthRegex, | ||||||
|  | @ -371,7 +202,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		PassAuthorization:       opts.PassAuthorization, | 		PassAuthorization:       opts.PassAuthorization, | ||||||
| 		PreferEmailToUser:       opts.PreferEmailToUser, | 		PreferEmailToUser:       opts.PreferEmailToUser, | ||||||
| 		SkipProviderButton:      opts.SkipProviderButton, | 		SkipProviderButton:      opts.SkipProviderButton, | ||||||
| 		templates:               loadTemplates(opts.CustomTemplatesDir), | 		templates:               templates, | ||||||
| 		trustedIPs:              trustedIPs, | 		trustedIPs:              trustedIPs, | ||||||
| 		Banner:                  opts.Banner, | 		Banner:                  opts.Banner, | ||||||
| 		Footer:                  opts.Footer, | 		Footer:                  opts.Footer, | ||||||
|  |  | ||||||
|  | @ -8,7 +8,6 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | @ -24,11 +23,11 @@ import ( | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
| 	sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" | 	sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/providers" | 	"github.com/oauth2-proxy/oauth2-proxy/providers" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| 	"golang.org/x/net/websocket" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -44,143 +43,6 @@ func init() { | ||||||
| 	logger.SetFlags(logger.Lshortfile) | 	logger.SetFlags(logger.Lshortfile) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type WebSocketOrRestHandler struct { |  | ||||||
| 	restHandler http.Handler |  | ||||||
| 	wsHandler   http.Handler |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 	if r.Header.Get("Upgrade") == "websocket" { |  | ||||||
| 		h.wsHandler.ServeHTTP(w, r) |  | ||||||
| 	} else { |  | ||||||
| 		h.restHandler.ServeHTTP(w, r) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestWebSocketProxy(t *testing.T) { |  | ||||||
| 	handler := WebSocketOrRestHandler{ |  | ||||||
| 		restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 			w.WriteHeader(200) |  | ||||||
| 			hostname, _, _ := net.SplitHostPort(r.Host) |  | ||||||
| 			_, err := w.Write([]byte(hostname)) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatal(err) |  | ||||||
| 			} |  | ||||||
| 		}), |  | ||||||
| 		wsHandler: websocket.Handler(func(ws *websocket.Conn) { |  | ||||||
| 			defer func(t *testing.T) { |  | ||||||
| 				if err := ws.Close(); err != nil { |  | ||||||
| 					t.Fatal(err) |  | ||||||
| 				} |  | ||||||
| 			}(t) |  | ||||||
| 			var data []byte |  | ||||||
| 			err := websocket.Message.Receive(ws, &data) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatal(err) |  | ||||||
| 			} |  | ||||||
| 			err = websocket.Message.Send(ws, data) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatal(err) |  | ||||||
| 			} |  | ||||||
| 		}), |  | ||||||
| 	} |  | ||||||
| 	backend := httptest.NewServer(&handler) |  | ||||||
| 	t.Cleanup(backend.Close) |  | ||||||
| 
 |  | ||||||
| 	backendURL, _ := url.Parse(backend.URL) |  | ||||||
| 
 |  | ||||||
| 	opts := baseTestOptions() |  | ||||||
| 	var auth hmacauth.HmacAuth |  | ||||||
| 	opts.PassHostHeader = true |  | ||||||
| 	proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth) |  | ||||||
| 	frontend := httptest.NewServer(proxyHandler) |  | ||||||
| 	t.Cleanup(frontend.Close) |  | ||||||
| 
 |  | ||||||
| 	frontendURL, _ := url.Parse(frontend.URL) |  | ||||||
| 	frontendWSURL := "ws://" + frontendURL.Host + "/" |  | ||||||
| 
 |  | ||||||
| 	ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	request := []byte("hello, world!") |  | ||||||
| 	err = websocket.Message.Send(ws, request) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	var response = make([]byte, 1024) |  | ||||||
| 	err = websocket.Message.Receive(ws, &response) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	if g, e := string(request), string(response); g != e { |  | ||||||
| 		t.Errorf("got body %q; expected %q", g, e) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	getReq, _ := http.NewRequest("GET", frontend.URL, nil) |  | ||||||
| 	res, _ := http.DefaultClient.Do(getReq) |  | ||||||
| 	bodyBytes, _ := ioutil.ReadAll(res.Body) |  | ||||||
| 	backendHostname, _, _ := net.SplitHostPort(backendURL.Host) |  | ||||||
| 	if g, e := string(bodyBytes), backendHostname; g != e { |  | ||||||
| 		t.Errorf("got body %q; expected %q", g, e) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestNewReverseProxy(t *testing.T) { |  | ||||||
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 		w.WriteHeader(200) |  | ||||||
| 		hostname, _, _ := net.SplitHostPort(r.Host) |  | ||||||
| 		_, err := w.Write([]byte(hostname)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(err) |  | ||||||
| 		} |  | ||||||
| 	})) |  | ||||||
| 	t.Cleanup(backend.Close) |  | ||||||
| 
 |  | ||||||
| 	backendURL, _ := url.Parse(backend.URL) |  | ||||||
| 	backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) |  | ||||||
| 	backendHost := net.JoinHostPort(backendHostname, backendPort) |  | ||||||
| 	proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") |  | ||||||
| 
 |  | ||||||
| 	proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second}) |  | ||||||
| 	setProxyUpstreamHostHeader(proxyHandler, proxyURL) |  | ||||||
| 	frontend := httptest.NewServer(proxyHandler) |  | ||||||
| 	t.Cleanup(frontend.Close) |  | ||||||
| 
 |  | ||||||
| 	getReq, _ := http.NewRequest("GET", frontend.URL, nil) |  | ||||||
| 	res, _ := http.DefaultClient.Do(getReq) |  | ||||||
| 	bodyBytes, _ := ioutil.ReadAll(res.Body) |  | ||||||
| 	if g, e := string(bodyBytes), backendHostname; g != e { |  | ||||||
| 		t.Errorf("got body %q; expected %q", g, e) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestEncodedSlashes(t *testing.T) { |  | ||||||
| 	var seen string |  | ||||||
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 		w.WriteHeader(200) |  | ||||||
| 		seen = r.RequestURI |  | ||||||
| 	})) |  | ||||||
| 	t.Cleanup(backend.Close) |  | ||||||
| 
 |  | ||||||
| 	b, _ := url.Parse(backend.URL) |  | ||||||
| 	proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second}) |  | ||||||
| 	setProxyDirector(proxyHandler) |  | ||||||
| 	frontend := httptest.NewServer(proxyHandler) |  | ||||||
| 	t.Cleanup(frontend.Close) |  | ||||||
| 
 |  | ||||||
| 	f, _ := url.Parse(frontend.URL) |  | ||||||
| 	encodedPath := "/a%2Fb/?c=1" |  | ||||||
| 	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} |  | ||||||
| 	_, err := http.DefaultClient.Do(getReq) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	if seen != encodedPath { |  | ||||||
| 		t.Errorf("got bad request %q expected %q", seen, encodedPath) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRobotsTxt(t *testing.T) { | func TestRobotsTxt(t *testing.T) { | ||||||
| 	opts := baseTestOptions() | 	opts := baseTestOptions() | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
|  | @ -562,7 +424,14 @@ func TestBasicAuthPassword(t *testing.T) { | ||||||
| 		} | 		} | ||||||
| 	})) | 	})) | ||||||
| 	opts := baseTestOptions() | 	opts := baseTestOptions() | ||||||
| 	opts.Upstreams = append(opts.Upstreams, providerServer.URL) | 	opts.UpstreamServers = options.Upstreams{ | ||||||
|  | 		{ | ||||||
|  | 			ID:   providerServer.URL, | ||||||
|  | 			Path: "/", | ||||||
|  | 			URI:  providerServer.URL, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	opts.Cookie.Secure = false | 	opts.Cookie.Secure = false | ||||||
| 	opts.PassBasicAuth = true | 	opts.PassBasicAuth = true | ||||||
| 	opts.SetBasicAuth = true | 	opts.SetBasicAuth = true | ||||||
|  | @ -867,7 +736,7 @@ type PassAccessTokenTest struct { | ||||||
| 
 | 
 | ||||||
| type PassAccessTokenTestOptions struct { | type PassAccessTokenTestOptions struct { | ||||||
| 	PassAccessToken bool | 	PassAccessToken bool | ||||||
| 	ProxyUpstream   string | 	ProxyUpstream   options.Upstream | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { | func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { | ||||||
|  | @ -893,10 +762,17 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe | ||||||
| 		})) | 		})) | ||||||
| 
 | 
 | ||||||
| 	patt.opts = baseTestOptions() | 	patt.opts = baseTestOptions() | ||||||
| 	patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL) | 	patt.opts.UpstreamServers = options.Upstreams{ | ||||||
| 	if opts.ProxyUpstream != "" { | 		{ | ||||||
| 		patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream) | 			ID:   patt.providerServer.URL, | ||||||
|  | 			Path: "/", | ||||||
|  | 			URI:  patt.providerServer.URL, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  | 	if opts.ProxyUpstream.ID != "" { | ||||||
|  | 		patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	patt.opts.Cookie.Secure = false | 	patt.opts.Cookie.Secure = false | ||||||
| 	patt.opts.PassAccessToken = opts.PassAccessToken | 	patt.opts.PassAccessToken = opts.PassAccessToken | ||||||
| 	err := validation.Validate(patt.opts) | 	err := validation.Validate(patt.opts) | ||||||
|  | @ -999,7 +875,11 @@ func TestForwardAccessTokenUpstream(t *testing.T) { | ||||||
| func TestStaticProxyUpstream(t *testing.T) { | func TestStaticProxyUpstream(t *testing.T) { | ||||||
| 	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | 	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
| 		PassAccessToken: true, | 		PassAccessToken: true, | ||||||
| 		ProxyUpstream:   "static://200/static-proxy", | 		ProxyUpstream: options.Upstream{ | ||||||
|  | 			ID:     "static-proxy", | ||||||
|  | 			Path:   "/static-proxy", | ||||||
|  | 			Static: true, | ||||||
|  | 		}, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|  | @ -1572,7 +1452,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { | ||||||
| 	t.Cleanup(upstream.Close) | 	t.Cleanup(upstream.Close) | ||||||
| 
 | 
 | ||||||
| 	opts := baseTestOptions() | 	opts := baseTestOptions() | ||||||
| 	opts.Upstreams = append(opts.Upstreams, upstream.URL) | 	opts.UpstreamServers = options.Upstreams{ | ||||||
|  | 		{ | ||||||
|  | 			ID:   upstream.URL, | ||||||
|  | 			Path: "/", | ||||||
|  | 			URI:  upstream.URL, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
| 	opts.SkipAuthPreflight = true | 	opts.SkipAuthPreflight = true | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  | @ -1641,7 +1527,13 @@ func NewSignatureTest() (*SignatureTest, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	opts.Upstreams = append(opts.Upstreams, upstream.URL) | 	opts.UpstreamServers = options.Upstreams{ | ||||||
|  | 		{ | ||||||
|  | 			ID:   upstream.URL, | ||||||
|  | 			Path: "/", | ||||||
|  | 			URI:  upstream.URL, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	providerHandler := func(w http.ResponseWriter, r *http.Request) { | 	providerHandler := func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`)) | 		_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`)) | ||||||
|  | @ -1716,7 +1608,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er | ||||||
| 	} | 	} | ||||||
| 	// This is used by the upstream to validate the signature.
 | 	// This is used by the upstream to validate the signature.
 | ||||||
| 	st.authenticator.auth = hmacauth.NewHmacAuth( | 	st.authenticator.auth = hmacauth.NewHmacAuth( | ||||||
| 		crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) | 		crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders) | ||||||
| 	proxy.ServeHTTP(st.rw, req) | 	proxy.ServeHTTP(st.rw, req) | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -2110,7 +2002,13 @@ func Test_noCacheHeaders(t *testing.T) { | ||||||
| 	t.Cleanup(upstream.Close) | 	t.Cleanup(upstream.Close) | ||||||
| 
 | 
 | ||||||
| 	opts := baseTestOptions() | 	opts := baseTestOptions() | ||||||
| 	opts.Upstreams = []string{upstream.URL} | 	opts.UpstreamServers = options.Upstreams{ | ||||||
|  | 		{ | ||||||
|  | 			ID:   upstream.URL, | ||||||
|  | 			Path: "/", | ||||||
|  | 			URI:  upstream.URL, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
| 	opts.SkipAuthRegex = []string{".*"} | 	opts.SkipAuthRegex = []string{".*"} | ||||||
| 	err := validation.Validate(opts) | 	err := validation.Validate(opts) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
|  | @ -2335,7 +2233,13 @@ func TestTrustedIPs(t *testing.T) { | ||||||
| 	for _, tt := range tests { | 	for _, tt := range tests { | ||||||
| 		t.Run(tt.name, func(t *testing.T) { | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 			opts := baseTestOptions() | 			opts := baseTestOptions() | ||||||
| 			opts.Upstreams = []string{"static://200"} | 			opts.UpstreamServers = options.Upstreams{ | ||||||
|  | 				{ | ||||||
|  | 					ID:     "static", | ||||||
|  | 					Path:   "/", | ||||||
|  | 					Static: true, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
| 			opts.TrustedIPs = tt.trustedIPs | 			opts.TrustedIPs = tt.trustedIPs | ||||||
| 			opts.ReverseProxy = tt.reverseProxy | 			opts.ReverseProxy = tt.reverseProxy | ||||||
| 			opts.RealClientIPHeader = tt.realClientIPHeader | 			opts.RealClientIPHeader = tt.realClientIPHeader | ||||||
|  |  | ||||||
|  | @ -0,0 +1,117 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	"github.com/spf13/pflag" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type LegacyOptions struct { | ||||||
|  | 	// Legacy options related to upstream servers
 | ||||||
|  | 	LegacyUpstreams LegacyUpstreams `cfg:",squash"` | ||||||
|  | 
 | ||||||
|  | 	Options Options `cfg:",squash"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewLegacyOptions() *LegacyOptions { | ||||||
|  | 	return &LegacyOptions{ | ||||||
|  | 		LegacyUpstreams: LegacyUpstreams{ | ||||||
|  | 			PassHostHeader:  true, | ||||||
|  | 			ProxyWebSockets: true, | ||||||
|  | 			FlushInterval:   time.Duration(1) * time.Second, | ||||||
|  | 		}, | ||||||
|  | 
 | ||||||
|  | 		Options: *NewOptions(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LegacyOptions) ToOptions() (*Options, error) { | ||||||
|  | 	upstreams, err := l.LegacyUpstreams.convert() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error converting upstreams: %v", err) | ||||||
|  | 	} | ||||||
|  | 	l.Options.UpstreamServers = upstreams | ||||||
|  | 
 | ||||||
|  | 	return &l.Options, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type LegacyUpstreams struct { | ||||||
|  | 	FlushInterval                 time.Duration `flag:"flush-interval" cfg:"flush_interval"` | ||||||
|  | 	PassHostHeader                bool          `flag:"pass-host-header" cfg:"pass_host_header"` | ||||||
|  | 	ProxyWebSockets               bool          `flag:"proxy-websockets" cfg:"proxy_websockets"` | ||||||
|  | 	SSLUpstreamInsecureSkipVerify bool          `flag:"ssl-upstream-insecure-skip-verify" cfg:"ssl_upstream_insecure_skip_verify"` | ||||||
|  | 	Upstreams                     []string      `flag:"upstream" cfg:"upstreams"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func legacyUpstreamsFlagSet() *pflag.FlagSet { | ||||||
|  | 	flagSet := pflag.NewFlagSet("upstreams", pflag.ExitOnError) | ||||||
|  | 
 | ||||||
|  | 	flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") | ||||||
|  | 	flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") | ||||||
|  | 	flagSet.Bool("proxy-websockets", true, "enables WebSocket proxying") | ||||||
|  | 	flagSet.Bool("ssl-upstream-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS upstreams") | ||||||
|  | 	flagSet.StringSlice("upstream", []string{}, "the http url(s) of the upstream endpoint, file:// paths for static files or static://<status_code> for static response. Routing is based on the path") | ||||||
|  | 
 | ||||||
|  | 	return flagSet | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l *LegacyUpstreams) convert() (Upstreams, error) { | ||||||
|  | 	upstreams := Upstreams{} | ||||||
|  | 
 | ||||||
|  | 	for _, upstreamString := range l.Upstreams { | ||||||
|  | 		u, err := url.Parse(upstreamString) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("could not parse upstream %q: %v", upstreamString, err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if u.Path == "" { | ||||||
|  | 			u.Path = "/" | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		upstream := Upstream{ | ||||||
|  | 			ID:                    u.Path, | ||||||
|  | 			Path:                  u.Path, | ||||||
|  | 			URI:                   upstreamString, | ||||||
|  | 			InsecureSkipTLSVerify: l.SSLUpstreamInsecureSkipVerify, | ||||||
|  | 			PassHostHeader:        &l.PassHostHeader, | ||||||
|  | 			ProxyWebSockets:       &l.ProxyWebSockets, | ||||||
|  | 			FlushInterval:         &l.FlushInterval, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		switch u.Scheme { | ||||||
|  | 		case "file": | ||||||
|  | 			if u.Fragment != "" { | ||||||
|  | 				upstream.ID = u.Fragment | ||||||
|  | 				upstream.Path = u.Fragment | ||||||
|  | 			} | ||||||
|  | 		case "static": | ||||||
|  | 			responseCode, err := strconv.Atoi(u.Host) | ||||||
|  | 			if err != nil { | ||||||
|  | 				logger.Printf("unable to convert %q to int, use default \"200\"", u.Host) | ||||||
|  | 				responseCode = 200 | ||||||
|  | 			} | ||||||
|  | 			upstream.Static = true | ||||||
|  | 			upstream.StaticCode = &responseCode | ||||||
|  | 
 | ||||||
|  | 			// These are not allowed to be empty and must be unique
 | ||||||
|  | 			upstream.ID = upstreamString | ||||||
|  | 			upstream.Path = upstreamString | ||||||
|  | 
 | ||||||
|  | 			// Force defaults compatible with static responses
 | ||||||
|  | 			upstream.URI = "" | ||||||
|  | 			upstream.InsecureSkipTLSVerify = false | ||||||
|  | 			upstream.PassHostHeader = nil | ||||||
|  | 			upstream.ProxyWebSockets = nil | ||||||
|  | 			flush := 1 * time.Second | ||||||
|  | 			upstream.FlushInterval = &flush | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		upstreams = append(upstreams, upstream) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return upstreams, nil | ||||||
|  | } | ||||||
|  | @ -0,0 +1,203 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Legacy Options", func() { | ||||||
|  | 	Context("ToOptions", func() { | ||||||
|  | 		It("converts the options as expected", func() { | ||||||
|  | 			opts := NewOptions() | ||||||
|  | 
 | ||||||
|  | 			legacyOpts := NewLegacyOptions() | ||||||
|  | 
 | ||||||
|  | 			// Set upstreams and related options to test their conversion
 | ||||||
|  | 			flushInterval := 5 * time.Second | ||||||
|  | 			legacyOpts.LegacyUpstreams.FlushInterval = flushInterval | ||||||
|  | 			legacyOpts.LegacyUpstreams.PassHostHeader = true | ||||||
|  | 			legacyOpts.LegacyUpstreams.ProxyWebSockets = true | ||||||
|  | 			legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true | ||||||
|  | 			legacyOpts.LegacyUpstreams.Upstreams = []string{"http://foo.bar/baz", "file://var/lib/website#/bar"} | ||||||
|  | 
 | ||||||
|  | 			truth := true | ||||||
|  | 			opts.UpstreamServers = Upstreams{ | ||||||
|  | 				{ | ||||||
|  | 					ID:                    "/baz", | ||||||
|  | 					Path:                  "/baz", | ||||||
|  | 					URI:                   "http://foo.bar/baz", | ||||||
|  | 					FlushInterval:         &flushInterval, | ||||||
|  | 					InsecureSkipTLSVerify: true, | ||||||
|  | 					PassHostHeader:        &truth, | ||||||
|  | 					ProxyWebSockets:       &truth, | ||||||
|  | 				}, | ||||||
|  | 				{ | ||||||
|  | 					ID:                    "/bar", | ||||||
|  | 					Path:                  "/bar", | ||||||
|  | 					URI:                   "file://var/lib/website#/bar", | ||||||
|  | 					FlushInterval:         &flushInterval, | ||||||
|  | 					InsecureSkipTLSVerify: true, | ||||||
|  | 					PassHostHeader:        &truth, | ||||||
|  | 					ProxyWebSockets:       &truth, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			converted, err := legacyOpts.ToOptions() | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			Expect(converted).To(Equal(opts)) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("Legacy Upstreams", func() { | ||||||
|  | 		type convertUpstreamsTableInput struct { | ||||||
|  | 			upstreamStrings   []string | ||||||
|  | 			expectedUpstreams Upstreams | ||||||
|  | 			errMsg            string | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		defaultFlushInterval := 1 * time.Second | ||||||
|  | 
 | ||||||
|  | 		// Non defaults for these options
 | ||||||
|  | 		skipVerify := true | ||||||
|  | 		passHostHeader := false | ||||||
|  | 		proxyWebSockets := true | ||||||
|  | 		flushInterval := 5 * time.Second | ||||||
|  | 
 | ||||||
|  | 		// Test cases and expected outcomes
 | ||||||
|  | 		validHTTP := "http://foo.bar/baz" | ||||||
|  | 		validHTTPUpstream := Upstream{ | ||||||
|  | 			ID:                    "/baz", | ||||||
|  | 			Path:                  "/baz", | ||||||
|  | 			URI:                   validHTTP, | ||||||
|  | 			InsecureSkipTLSVerify: skipVerify, | ||||||
|  | 			PassHostHeader:        &passHostHeader, | ||||||
|  | 			ProxyWebSockets:       &proxyWebSockets, | ||||||
|  | 			FlushInterval:         &flushInterval, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		// Test cases and expected outcomes
 | ||||||
|  | 		emptyPathHTTP := "http://foo.bar" | ||||||
|  | 		emptyPathHTTPUpstream := Upstream{ | ||||||
|  | 			ID:                    "/", | ||||||
|  | 			Path:                  "/", | ||||||
|  | 			URI:                   emptyPathHTTP, | ||||||
|  | 			InsecureSkipTLSVerify: skipVerify, | ||||||
|  | 			PassHostHeader:        &passHostHeader, | ||||||
|  | 			ProxyWebSockets:       &proxyWebSockets, | ||||||
|  | 			FlushInterval:         &flushInterval, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		validFileWithFragment := "file://var/lib/website#/bar" | ||||||
|  | 		validFileWithFragmentUpstream := Upstream{ | ||||||
|  | 			ID:                    "/bar", | ||||||
|  | 			Path:                  "/bar", | ||||||
|  | 			URI:                   validFileWithFragment, | ||||||
|  | 			InsecureSkipTLSVerify: skipVerify, | ||||||
|  | 			PassHostHeader:        &passHostHeader, | ||||||
|  | 			ProxyWebSockets:       &proxyWebSockets, | ||||||
|  | 			FlushInterval:         &flushInterval, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		validStatic := "static://204" | ||||||
|  | 		validStaticCode := 204 | ||||||
|  | 		validStaticUpstream := Upstream{ | ||||||
|  | 			ID:                    validStatic, | ||||||
|  | 			Path:                  validStatic, | ||||||
|  | 			URI:                   "", | ||||||
|  | 			Static:                true, | ||||||
|  | 			StaticCode:            &validStaticCode, | ||||||
|  | 			InsecureSkipTLSVerify: false, | ||||||
|  | 			PassHostHeader:        nil, | ||||||
|  | 			ProxyWebSockets:       nil, | ||||||
|  | 			FlushInterval:         &defaultFlushInterval, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		invalidStatic := "static://abc" | ||||||
|  | 		invalidStaticCode := 200 | ||||||
|  | 		invalidStaticUpstream := Upstream{ | ||||||
|  | 			ID:                    invalidStatic, | ||||||
|  | 			Path:                  invalidStatic, | ||||||
|  | 			URI:                   "", | ||||||
|  | 			Static:                true, | ||||||
|  | 			StaticCode:            &invalidStaticCode, | ||||||
|  | 			InsecureSkipTLSVerify: false, | ||||||
|  | 			PassHostHeader:        nil, | ||||||
|  | 			ProxyWebSockets:       nil, | ||||||
|  | 			FlushInterval:         &defaultFlushInterval, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		invalidHTTP := ":foo" | ||||||
|  | 		invalidHTTPErrMsg := "could not parse upstream \":foo\": parse \":foo\": missing protocol scheme" | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("convertLegacyUpstreams", | ||||||
|  | 			func(o *convertUpstreamsTableInput) { | ||||||
|  | 				legacyUpstreams := LegacyUpstreams{ | ||||||
|  | 					Upstreams:                     o.upstreamStrings, | ||||||
|  | 					SSLUpstreamInsecureSkipVerify: skipVerify, | ||||||
|  | 					PassHostHeader:                passHostHeader, | ||||||
|  | 					ProxyWebSockets:               proxyWebSockets, | ||||||
|  | 					FlushInterval:                 flushInterval, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				upstreams, err := legacyUpstreams.convert() | ||||||
|  | 
 | ||||||
|  | 				if o.errMsg != "" { | ||||||
|  | 					Expect(err).To(HaveOccurred()) | ||||||
|  | 					Expect(err.Error()).To(Equal(o.errMsg)) | ||||||
|  | 				} else { | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				Expect(upstreams).To(ConsistOf(o.expectedUpstreams)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with no upstreams", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{}, | ||||||
|  | 				expectedUpstreams: Upstreams{}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a valid HTTP upstream", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{validHTTP}, | ||||||
|  | 				expectedUpstreams: Upstreams{validHTTPUpstream}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a HTTP upstream with an empty path", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{emptyPathHTTP}, | ||||||
|  | 				expectedUpstreams: Upstreams{emptyPathHTTPUpstream}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a valid File upstream with a fragment", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{validFileWithFragment}, | ||||||
|  | 				expectedUpstreams: Upstreams{validFileWithFragmentUpstream}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a valid static upstream", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{validStatic}, | ||||||
|  | 				expectedUpstreams: Upstreams{validStaticUpstream}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid static upstream, code is 200", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{invalidStatic}, | ||||||
|  | 				expectedUpstreams: Upstreams{invalidStaticUpstream}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid HTTP upstream", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{invalidHTTP}, | ||||||
|  | 				expectedUpstreams: Upstreams{}, | ||||||
|  | 				errMsg:            invalidHTTPErrMsg, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid HTTP upstream and other upstreams", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{validHTTP, invalidHTTP}, | ||||||
|  | 				expectedUpstreams: Upstreams{}, | ||||||
|  | 				errMsg:            invalidHTTPErrMsg, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with multiple valid upstreams", &convertUpstreamsTableInput{ | ||||||
|  | 				upstreamStrings:   []string{validHTTP, validFileWithFragment, validStatic}, | ||||||
|  | 				expectedUpstreams: Upstreams{validHTTPUpstream, validFileWithFragmentUpstream, validStaticUpstream}, | ||||||
|  | 				errMsg:            "", | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" |  | ||||||
| 
 | 
 | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | @ -12,11 +11,6 @@ import ( | ||||||
| 	"github.com/spf13/pflag" | 	"github.com/spf13/pflag" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestOptionsSuite(t *testing.T) { |  | ||||||
| 	RegisterFailHandler(Fail) |  | ||||||
| 	RunSpecs(t, "Options Suite") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| var _ = Describe("Load", func() { | var _ = Describe("Load", func() { | ||||||
| 	Context("with a testOptions structure", func() { | 	Context("with a testOptions structure", func() { | ||||||
| 		type TestOptionSubStruct struct { | 		type TestOptionSubStruct struct { | ||||||
|  | @ -300,6 +294,11 @@ var _ = Describe("Load", func() { | ||||||
| 				input:          &Options{}, | 				input:          &Options{}, | ||||||
| 				expectedOutput: NewOptions(), | 				expectedOutput: NewOptions(), | ||||||
| 			}), | 			}), | ||||||
|  | 			Entry("with an empty LegacyOptions struct, should return default values", &testOptionsTableInput{ | ||||||
|  | 				flagSet:        NewFlagSet, | ||||||
|  | 				input:          &LegacyOptions{}, | ||||||
|  | 				expectedOutput: NewLegacyOptions(), | ||||||
|  | 			}), | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"crypto" | 	"crypto" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"time" |  | ||||||
| 
 | 
 | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
| 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | ||||||
|  | @ -24,7 +23,6 @@ type Options struct { | ||||||
| 	ProxyPrefix        string   `flag:"proxy-prefix" cfg:"proxy_prefix"` | 	ProxyPrefix        string   `flag:"proxy-prefix" cfg:"proxy_prefix"` | ||||||
| 	PingPath           string   `flag:"ping-path" cfg:"ping_path"` | 	PingPath           string   `flag:"ping-path" cfg:"ping_path"` | ||||||
| 	PingUserAgent      string   `flag:"ping-user-agent" cfg:"ping_user_agent"` | 	PingUserAgent      string   `flag:"ping-user-agent" cfg:"ping_user_agent"` | ||||||
| 	ProxyWebSockets    bool     `flag:"proxy-websockets" cfg:"proxy_websockets"` |  | ||||||
| 	HTTPAddress        string   `flag:"http-address" cfg:"http_address"` | 	HTTPAddress        string   `flag:"http-address" cfg:"http_address"` | ||||||
| 	HTTPSAddress       string   `flag:"https-address" cfg:"https_address"` | 	HTTPSAddress       string   `flag:"https-address" cfg:"https_address"` | ||||||
| 	ReverseProxy       bool     `flag:"reverse-proxy" cfg:"reverse_proxy"` | 	ReverseProxy       bool     `flag:"reverse-proxy" cfg:"reverse_proxy"` | ||||||
|  | @ -64,26 +62,26 @@ type Options struct { | ||||||
| 	Session SessionOptions `cfg:",squash"` | 	Session SessionOptions `cfg:",squash"` | ||||||
| 	Logging Logging        `cfg:",squash"` | 	Logging Logging        `cfg:",squash"` | ||||||
| 
 | 
 | ||||||
| 	Upstreams                     []string      `flag:"upstream" cfg:"upstreams"` | 	// Not used in the legacy config, name not allowed to match an external key (upstreams)
 | ||||||
| 	SkipAuthRegex                 []string      `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | 	// TODO(JoelSpeed): Rename when legacy config is removed
 | ||||||
| 	SkipAuthStripHeaders          bool          `flag:"skip-auth-strip-headers" cfg:"skip_auth_strip_headers"` | 	UpstreamServers Upstreams `cfg:",internal"` | ||||||
| 	SkipJwtBearerTokens           bool          `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` | 
 | ||||||
| 	ExtraJwtIssuers               []string      `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` | 	SkipAuthRegex         []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | ||||||
| 	PassBasicAuth                 bool          `flag:"pass-basic-auth" cfg:"pass_basic_auth"` | 	SkipAuthStripHeaders  bool     `flag:"skip-auth-strip-headers" cfg:"skip_auth_strip_headers"` | ||||||
| 	SetBasicAuth                  bool          `flag:"set-basic-auth" cfg:"set_basic_auth"` | 	SkipJwtBearerTokens   bool     `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` | ||||||
| 	PreferEmailToUser             bool          `flag:"prefer-email-to-user" cfg:"prefer_email_to_user"` | 	ExtraJwtIssuers       []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` | ||||||
| 	BasicAuthPassword             string        `flag:"basic-auth-password" cfg:"basic_auth_password"` | 	PassBasicAuth         bool     `flag:"pass-basic-auth" cfg:"pass_basic_auth"` | ||||||
| 	PassAccessToken               bool          `flag:"pass-access-token" cfg:"pass_access_token"` | 	SetBasicAuth          bool     `flag:"set-basic-auth" cfg:"set_basic_auth"` | ||||||
| 	PassHostHeader                bool          `flag:"pass-host-header" cfg:"pass_host_header"` | 	PreferEmailToUser     bool     `flag:"prefer-email-to-user" cfg:"prefer_email_to_user"` | ||||||
| 	SkipProviderButton            bool          `flag:"skip-provider-button" cfg:"skip_provider_button"` | 	BasicAuthPassword     string   `flag:"basic-auth-password" cfg:"basic_auth_password"` | ||||||
| 	PassUserHeaders               bool          `flag:"pass-user-headers" cfg:"pass_user_headers"` | 	PassAccessToken       bool     `flag:"pass-access-token" cfg:"pass_access_token"` | ||||||
| 	SSLInsecureSkipVerify         bool          `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` | 	SkipProviderButton    bool     `flag:"skip-provider-button" cfg:"skip_provider_button"` | ||||||
| 	SSLUpstreamInsecureSkipVerify bool          `flag:"ssl-upstream-insecure-skip-verify" cfg:"ssl_upstream_insecure_skip_verify"` | 	PassUserHeaders       bool     `flag:"pass-user-headers" cfg:"pass_user_headers"` | ||||||
| 	SetXAuthRequest               bool          `flag:"set-xauthrequest" cfg:"set_xauthrequest"` | 	SSLInsecureSkipVerify bool     `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` | ||||||
| 	SetAuthorization              bool          `flag:"set-authorization-header" cfg:"set_authorization_header"` | 	SetXAuthRequest       bool     `flag:"set-xauthrequest" cfg:"set_xauthrequest"` | ||||||
| 	PassAuthorization             bool          `flag:"pass-authorization-header" cfg:"pass_authorization_header"` | 	SetAuthorization      bool     `flag:"set-authorization-header" cfg:"set_authorization_header"` | ||||||
| 	SkipAuthPreflight             bool          `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` | 	PassAuthorization     bool     `flag:"pass-authorization-header" cfg:"pass_authorization_header"` | ||||||
| 	FlushInterval                 time.Duration `flag:"flush-interval" cfg:"flush_interval"` | 	SkipAuthPreflight     bool     `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` | ||||||
| 
 | 
 | ||||||
| 	// These options allow for other providers besides Google, with
 | 	// These options allow for other providers besides Google, with
 | ||||||
| 	// potential overrides.
 | 	// potential overrides.
 | ||||||
|  | @ -114,7 +112,6 @@ type Options struct { | ||||||
| 
 | 
 | ||||||
| 	// internal values that are set after config validation
 | 	// internal values that are set after config validation
 | ||||||
| 	redirectURL        *url.URL | 	redirectURL        *url.URL | ||||||
| 	proxyURLs          []*url.URL |  | ||||||
| 	compiledRegex      []*regexp.Regexp | 	compiledRegex      []*regexp.Regexp | ||||||
| 	provider           providers.Provider | 	provider           providers.Provider | ||||||
| 	signatureData      *SignatureData | 	signatureData      *SignatureData | ||||||
|  | @ -125,7 +122,6 @@ type Options struct { | ||||||
| 
 | 
 | ||||||
| // Options for Getting internal values
 | // Options for Getting internal values
 | ||||||
| func (o *Options) GetRedirectURL() *url.URL                        { return o.redirectURL } | func (o *Options) GetRedirectURL() *url.URL                        { return o.redirectURL } | ||||||
| func (o *Options) GetProxyURLs() []*url.URL                        { return o.proxyURLs } |  | ||||||
| func (o *Options) GetCompiledRegex() []*regexp.Regexp              { return o.compiledRegex } | func (o *Options) GetCompiledRegex() []*regexp.Regexp              { return o.compiledRegex } | ||||||
| func (o *Options) GetProvider() providers.Provider                 { return o.provider } | func (o *Options) GetProvider() providers.Provider                 { return o.provider } | ||||||
| func (o *Options) GetSignatureData() *SignatureData                { return o.signatureData } | func (o *Options) GetSignatureData() *SignatureData                { return o.signatureData } | ||||||
|  | @ -135,7 +131,6 @@ func (o *Options) GetRealClientIPParser() ipapi.RealClientIPParser { return o.re | ||||||
| 
 | 
 | ||||||
| // Options for Setting internal values
 | // Options for Setting internal values
 | ||||||
| func (o *Options) SetRedirectURL(s *url.URL)                        { o.redirectURL = s } | func (o *Options) SetRedirectURL(s *url.URL)                        { o.redirectURL = s } | ||||||
| func (o *Options) SetProxyURLs(s []*url.URL)                        { o.proxyURLs = s } |  | ||||||
| func (o *Options) SetCompiledRegex(s []*regexp.Regexp)              { o.compiledRegex = s } | func (o *Options) SetCompiledRegex(s []*regexp.Regexp)              { o.compiledRegex = s } | ||||||
| func (o *Options) SetProvider(s providers.Provider)                 { o.provider = s } | func (o *Options) SetProvider(s providers.Provider)                 { o.provider = s } | ||||||
| func (o *Options) SetSignatureData(s *SignatureData)                { o.signatureData = s } | func (o *Options) SetSignatureData(s *SignatureData)                { o.signatureData = s } | ||||||
|  | @ -149,7 +144,6 @@ func NewOptions() *Options { | ||||||
| 		ProxyPrefix:                      "/oauth2", | 		ProxyPrefix:                      "/oauth2", | ||||||
| 		ProviderType:                     "google", | 		ProviderType:                     "google", | ||||||
| 		PingPath:                         "/ping", | 		PingPath:                         "/ping", | ||||||
| 		ProxyWebSockets:                  true, |  | ||||||
| 		HTTPAddress:                      "127.0.0.1:4180", | 		HTTPAddress:                      "127.0.0.1:4180", | ||||||
| 		HTTPSAddress:                     ":443", | 		HTTPSAddress:                     ":443", | ||||||
| 		RealClientIPHeader:               "X-Real-IP", | 		RealClientIPHeader:               "X-Real-IP", | ||||||
|  | @ -160,13 +154,10 @@ func NewOptions() *Options { | ||||||
| 		AzureTenant:                      "common", | 		AzureTenant:                      "common", | ||||||
| 		SetXAuthRequest:                  false, | 		SetXAuthRequest:                  false, | ||||||
| 		SkipAuthPreflight:                false, | 		SkipAuthPreflight:                false, | ||||||
| 		SkipAuthStripHeaders:             false, |  | ||||||
| 		FlushInterval:                    time.Duration(1) * time.Second, |  | ||||||
| 		PassBasicAuth:                    true, | 		PassBasicAuth:                    true, | ||||||
| 		SetBasicAuth:                     false, | 		SetBasicAuth:                     false, | ||||||
| 		PassUserHeaders:                  true, | 		PassUserHeaders:                  true, | ||||||
| 		PassAccessToken:                  false, | 		PassAccessToken:                  false, | ||||||
| 		PassHostHeader:                   true, |  | ||||||
| 		SetAuthorization:                 false, | 		SetAuthorization:                 false, | ||||||
| 		PassAuthorization:                false, | 		PassAuthorization:                false, | ||||||
| 		PreferEmailToUser:                false, | 		PreferEmailToUser:                false, | ||||||
|  | @ -193,14 +184,12 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 	flagSet.String("tls-key-file", "", "path to private key file") | 	flagSet.String("tls-key-file", "", "path to private key file") | ||||||
| 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | ||||||
| 	flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)") | 	flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)") | ||||||
| 	flagSet.StringSlice("upstream", []string{}, "the http url(s) of the upstream endpoint, file:// paths for static files or static://<status_code> for static response. Routing is based on the path") |  | ||||||
| 	flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") | 	flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") | ||||||
| 	flagSet.Bool("set-basic-auth", false, "set HTTP Basic Auth information in response (useful in Nginx auth_request mode)") | 	flagSet.Bool("set-basic-auth", false, "set HTTP Basic Auth information in response (useful in Nginx auth_request mode)") | ||||||
| 	flagSet.Bool("prefer-email-to-user", false, "Prefer to use the Email address as the Username when passing information to upstream. Will only use Username if Email is unavailable, eg. htaccess authentication. Used in conjunction with -pass-basic-auth and -pass-user-headers") | 	flagSet.Bool("prefer-email-to-user", false, "Prefer to use the Email address as the Username when passing information to upstream. Will only use Username if Email is unavailable, eg. htaccess authentication. Used in conjunction with -pass-basic-auth and -pass-user-headers") | ||||||
| 	flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") | 	flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") | ||||||
| 	flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") | 	flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") | ||||||
| 	flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") | 	flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") | ||||||
| 	flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") |  | ||||||
| 	flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream") | 	flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream") | ||||||
| 	flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)") | 	flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)") | ||||||
| 	flagSet.StringSlice("skip-auth-regex", []string{}, "bypass authentication for requests path's that match (may be given multiple times)") | 	flagSet.StringSlice("skip-auth-regex", []string{}, "bypass authentication for requests path's that match (may be given multiple times)") | ||||||
|  | @ -208,8 +197,6 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 	flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") | 	flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") | ||||||
| 	flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") | 	flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") | ||||||
| 	flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers") | 	flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers") | ||||||
| 	flagSet.Bool("ssl-upstream-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS upstreams") |  | ||||||
| 	flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") |  | ||||||
| 	flagSet.Bool("skip-jwt-bearer-tokens", false, "will skip requests that have verified JWT bearer tokens (default false)") | 	flagSet.Bool("skip-jwt-bearer-tokens", false, "will skip requests that have verified JWT bearer tokens (default false)") | ||||||
| 	flagSet.StringSlice("extra-jwt-issuers", []string{}, "if skip-jwt-bearer-tokens is set, a list of extra JWT issuer=audience pairs (where the issuer URL has a .well-known/openid-configuration or a .well-known/jwks.json)") | 	flagSet.StringSlice("extra-jwt-issuers", []string{}, "if skip-jwt-bearer-tokens is set, a list of extra JWT issuer=audience pairs (where the issuer URL has a .well-known/openid-configuration or a .well-known/jwks.json)") | ||||||
| 
 | 
 | ||||||
|  | @ -240,7 +227,6 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | ||||||
| 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | ||||||
| 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | ||||||
| 	flagSet.Bool("proxy-websockets", true, "enables WebSocket proxying") |  | ||||||
| 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | ||||||
| 	flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") | 	flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") | ||||||
| 	flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") | 	flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") | ||||||
|  | @ -280,6 +266,7 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 
 | 
 | ||||||
| 	flagSet.AddFlagSet(cookieFlagSet()) | 	flagSet.AddFlagSet(cookieFlagSet()) | ||||||
| 	flagSet.AddFlagSet(loggingFlagSet()) | 	flagSet.AddFlagSet(loggingFlagSet()) | ||||||
|  | 	flagSet.AddFlagSet(legacyUpstreamsFlagSet()) | ||||||
| 
 | 
 | ||||||
| 	return flagSet | 	return flagSet | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,16 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestOptionsSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 
 | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "Options Suite") | ||||||
|  | } | ||||||
|  | @ -52,9 +52,9 @@ type Upstream struct { | ||||||
| 	// PassHostHeader determines whether the request host header should be proxied
 | 	// PassHostHeader determines whether the request host header should be proxied
 | ||||||
| 	// to the upstream server.
 | 	// to the upstream server.
 | ||||||
| 	// Defaults to true.
 | 	// Defaults to true.
 | ||||||
| 	PassHostHeader bool `json:"passHostHeader"` | 	PassHostHeader *bool `json:"passHostHeader"` | ||||||
| 
 | 
 | ||||||
| 	// ProxyWebSockets enables proxying of websockets to upstream servers
 | 	// ProxyWebSockets enables proxying of websockets to upstream servers
 | ||||||
| 	// Defaults to true.
 | 	// Defaults to true.
 | ||||||
| 	ProxyWebSockets bool `json:"proxyWebSockets"` | 	ProxyWebSockets *bool `json:"proxyWebSockets"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -49,7 +49,7 @@ func newHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *option | ||||||
| 
 | 
 | ||||||
| 	// Set up a WebSocket proxy if required
 | 	// Set up a WebSocket proxy if required
 | ||||||
| 	var wsProxy http.Handler | 	var wsProxy http.Handler | ||||||
| 	if upstream.ProxyWebSockets { | 	if upstream.ProxyWebSockets == nil || *upstream.ProxyWebSockets { | ||||||
| 		wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) | 		wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -110,7 +110,7 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Set the request director based on the PassHostHeader option
 | 	// Set the request director based on the PassHostHeader option
 | ||||||
| 	if !upstream.PassHostHeader { | 	if upstream.PassHostHeader != nil && !*upstream.PassHostHeader { | ||||||
| 		setProxyUpstreamHostHeader(proxy, target) | 		setProxyUpstreamHostHeader(proxy, target) | ||||||
| 	} else { | 	} else { | ||||||
| 		setProxyDirector(proxy) | 		setProxyDirector(proxy) | ||||||
|  |  | ||||||
|  | @ -24,6 +24,8 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 
 | 
 | ||||||
| 	const flushInterval5s = 5 * time.Second | 	const flushInterval5s = 5 * time.Second | ||||||
| 	const flushInterval1s = 1 * time.Second | 	const flushInterval1s = 1 * time.Second | ||||||
|  | 	truth := true | ||||||
|  | 	falsum := false | ||||||
| 
 | 
 | ||||||
| 	type httpUpstreamTableInput struct { | 	type httpUpstreamTableInput struct { | ||||||
| 		id               string | 		id               string | ||||||
|  | @ -51,10 +53,11 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			rw := httptest.NewRecorder() | 			rw := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 			flush := 1 * time.Second | 			flush := 1 * time.Second | ||||||
|  | 
 | ||||||
| 			upstream := options.Upstream{ | 			upstream := options.Upstream{ | ||||||
| 				ID:                    in.id, | 				ID:                    in.id, | ||||||
| 				PassHostHeader:        true, | 				PassHostHeader:        &truth, | ||||||
| 				ProxyWebSockets:       false, | 				ProxyWebSockets:       &falsum, | ||||||
| 				InsecureSkipTLSVerify: false, | 				InsecureSkipTLSVerify: false, | ||||||
| 				FlushInterval:         &flush, | 				FlushInterval:         &flush, | ||||||
| 			} | 			} | ||||||
|  | @ -258,8 +261,8 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 		flush := 1 * time.Second | 		flush := 1 * time.Second | ||||||
| 		upstream := options.Upstream{ | 		upstream := options.Upstream{ | ||||||
| 			ID:                    "noPassHost", | 			ID:                    "noPassHost", | ||||||
| 			PassHostHeader:        false, | 			PassHostHeader:        &falsum, | ||||||
| 			ProxyWebSockets:       false, | 			ProxyWebSockets:       &falsum, | ||||||
| 			InsecureSkipTLSVerify: false, | 			InsecureSkipTLSVerify: false, | ||||||
| 			FlushInterval:         &flush, | 			FlushInterval:         &flush, | ||||||
| 		} | 		} | ||||||
|  | @ -302,7 +305,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 				ID:                    "foo123", | 				ID:                    "foo123", | ||||||
| 				FlushInterval:         &in.flushInterval, | 				FlushInterval:         &in.flushInterval, | ||||||
| 				InsecureSkipTLSVerify: in.skipVerify, | 				InsecureSkipTLSVerify: in.skipVerify, | ||||||
| 				ProxyWebSockets:       in.proxyWebSockets, | 				ProxyWebSockets:       &in.proxyWebSockets, | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler) | 			handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler) | ||||||
|  | @ -370,8 +373,8 @@ var _ = Describe("HTTP Upstream Suite", func() { | ||||||
| 			flush := 1 * time.Second | 			flush := 1 * time.Second | ||||||
| 			upstream := options.Upstream{ | 			upstream := options.Upstream{ | ||||||
| 				ID:                    "websocketProxy", | 				ID:                    "websocketProxy", | ||||||
| 				PassHostHeader:        true, | 				PassHostHeader:        &truth, | ||||||
| 				ProxyWebSockets:       true, | 				ProxyWebSockets:       &truth, | ||||||
| 				InsecureSkipTLSVerify: false, | 				InsecureSkipTLSVerify: false, | ||||||
| 				FlushInterval:         &flush, | 				FlushInterval:         &flush, | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | @ -176,17 +176,7 @@ func Validate(o *options.Options) error { | ||||||
| 	redirectURL, msgs = parseURL(o.RawRedirectURL, "redirect", msgs) | 	redirectURL, msgs = parseURL(o.RawRedirectURL, "redirect", msgs) | ||||||
| 	o.SetRedirectURL(redirectURL) | 	o.SetRedirectURL(redirectURL) | ||||||
| 
 | 
 | ||||||
| 	for _, u := range o.Upstreams { | 	msgs = append(msgs, validateUpstreams(o.UpstreamServers)...) | ||||||
| 		upstreamURL, err := url.Parse(u) |  | ||||||
| 		if err != nil { |  | ||||||
| 			msgs = append(msgs, fmt.Sprintf("error parsing upstream: %s", err)) |  | ||||||
| 		} else { |  | ||||||
| 			if upstreamURL.Path == "" { |  | ||||||
| 				upstreamURL.Path = "/" |  | ||||||
| 			} |  | ||||||
| 			o.SetProxyURLs(append(o.GetProxyURLs(), upstreamURL)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	for _, u := range o.SkipAuthRegex { | 	for _, u := range o.SkipAuthRegex { | ||||||
| 		compiledRegex, err := regexp.Compile(u) | 		compiledRegex, err := regexp.Compile(u) | ||||||
|  |  | ||||||
|  | @ -22,7 +22,11 @@ const ( | ||||||
| 
 | 
 | ||||||
| func testOptions() *options.Options { | func testOptions() *options.Options { | ||||||
| 	o := options.NewOptions() | 	o := options.NewOptions() | ||||||
| 	o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8080/") | 	o.UpstreamServers = append(o.UpstreamServers, options.Upstream{ | ||||||
|  | 		ID:   "upstream", | ||||||
|  | 		Path: "/", | ||||||
|  | 		URI:  "http://127.0.0.1:8080/", | ||||||
|  | 	}) | ||||||
| 	o.Cookie.Secret = cookieSecret | 	o.Cookie.Secret = cookieSecret | ||||||
| 	o.ClientID = clientID | 	o.ClientID = clientID | ||||||
| 	o.ClientSecret = clientSecret | 	o.ClientSecret = clientSecret | ||||||
|  | @ -140,26 +144,6 @@ func TestRedirectURL(t *testing.T) { | ||||||
| 	assert.Equal(t, expected, o.GetRedirectURL()) | 	assert.Equal(t, expected, o.GetRedirectURL()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProxyURLs(t *testing.T) { |  | ||||||
| 	o := testOptions() |  | ||||||
| 	o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") |  | ||||||
| 	assert.Equal(t, nil, Validate(o)) |  | ||||||
| 	expected := []*url.URL{ |  | ||||||
| 		{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, |  | ||||||
| 		// note the '/' was added
 |  | ||||||
| 		{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, |  | ||||||
| 	} |  | ||||||
| 	assert.Equal(t, expected, o.GetProxyURLs()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestProxyURLsError(t *testing.T) { |  | ||||||
| 	o := testOptions() |  | ||||||
| 	o.Upstreams = append(o.Upstreams, "127.0.0.1:8081") |  | ||||||
| 	err := Validate(o) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| 	assert.Contains(t, err.Error(), "error parsing upstream") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestCompiledRegex(t *testing.T) { | func TestCompiledRegex(t *testing.T) { | ||||||
| 	o := testOptions() | 	o := testOptions() | ||||||
| 	regexps := []string{"/foo/.*", "/ba[rz]/quux"} | 	regexps := []string{"/foo/.*", "/ba[rz]/quux"} | ||||||
|  |  | ||||||
|  | @ -73,10 +73,10 @@ func validateStaticUpstream(upstream options.Upstream) []string { | ||||||
| 	if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second { | 	if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second { | ||||||
| 		msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) | 		msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) | ||||||
| 	} | 	} | ||||||
| 	if !upstream.PassHostHeader { | 	if upstream.PassHostHeader != nil { | ||||||
| 		msgs = append(msgs, fmt.Sprintf("upstream %q has passHostHeader, but is a static upstream, this will have no effect.", upstream.ID)) | 		msgs = append(msgs, fmt.Sprintf("upstream %q has passHostHeader, but is a static upstream, this will have no effect.", upstream.ID)) | ||||||
| 	} | 	} | ||||||
| 	if !upstream.ProxyWebSockets { | 	if upstream.ProxyWebSockets != nil { | ||||||
| 		msgs = append(msgs, fmt.Sprintf("upstream %q has proxyWebSockets, but is a static upstream, this will have no effect.", upstream.ID)) | 		msgs = append(msgs, fmt.Sprintf("upstream %q has proxyWebSockets, but is a static upstream, this will have no effect.", upstream.ID)) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ var _ = Describe("Upstreams", func() { | ||||||
| 
 | 
 | ||||||
| 	flushInterval := 5 * time.Second | 	flushInterval := 5 * time.Second | ||||||
| 	staticCode200 := 200 | 	staticCode200 := 200 | ||||||
|  | 	truth := true | ||||||
| 
 | 
 | ||||||
| 	validHTTPUpstream := options.Upstream{ | 	validHTTPUpstream := options.Upstream{ | ||||||
| 		ID:   "validHTTPUpstream", | 		ID:   "validHTTPUpstream", | ||||||
|  | @ -24,11 +25,9 @@ var _ = Describe("Upstreams", func() { | ||||||
| 		URI:  "http://localhost:8080", | 		URI:  "http://localhost:8080", | ||||||
| 	} | 	} | ||||||
| 	validStaticUpstream := options.Upstream{ | 	validStaticUpstream := options.Upstream{ | ||||||
| 		ID:              "validStaticUpstream", | 		ID:     "validStaticUpstream", | ||||||
| 		Path:            "/validStaticUpstream", | 		Path:   "/validStaticUpstream", | ||||||
| 		Static:          true, | 		Static: true, | ||||||
| 		PassHostHeader:  true, // This would normally be defaulted
 |  | ||||||
| 		ProxyWebSockets: true, // this would normally be defaulted
 |  | ||||||
| 	} | 	} | ||||||
| 	validFileUpstream := options.Upstream{ | 	validFileUpstream := options.Upstream{ | ||||||
| 		ID:   "validFileUpstream", | 		ID:   "validFileUpstream", | ||||||
|  | @ -134,8 +133,8 @@ var _ = Describe("Upstreams", func() { | ||||||
| 					URI:                   "ftp://foo", | 					URI:                   "ftp://foo", | ||||||
| 					Static:                true, | 					Static:                true, | ||||||
| 					FlushInterval:         &flushInterval, | 					FlushInterval:         &flushInterval, | ||||||
| 					PassHostHeader:        false, | 					PassHostHeader:        &truth, | ||||||
| 					ProxyWebSockets:       false, | 					ProxyWebSockets:       &truth, | ||||||
| 					InsecureSkipTLSVerify: true, | 					InsecureSkipTLSVerify: true, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue