Merge pull request #80 from 18F/pass-access-token
Pass the access token to the upstream server
This commit is contained in:
		
						commit
						864d4787e9
					
				|  | @ -76,6 +76,7 @@ Usage of google_auth_proxy: | ||||||
|   -htpasswd-file="": additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption |   -htpasswd-file="": additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption | ||||||
|   -http-address="127.0.0.1:4180": [http://]<addr>:<port> or unix://<path> to listen on for HTTP clients |   -http-address="127.0.0.1:4180": [http://]<addr>:<port> or unix://<path> to listen on for HTTP clients | ||||||
|   -login-url="": Authentication endpoint |   -login-url="": Authentication endpoint | ||||||
|  |   -pass-access-token=false: pass OAuth access_token to upstream via X-Forwarded-Access-Token header | ||||||
|   -pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream |   -pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream | ||||||
|   -pass-host-header=true: pass the request Host Header to upstream |   -pass-host-header=true: pass the request Host Header to upstream | ||||||
|   -profile-url="": Profile access endpoint |   -profile-url="": Profile access endpoint | ||||||
|  |  | ||||||
							
								
								
									
										38
									
								
								cookies.go
								
								
								
								
							
							
						
						
									
										38
									
								
								cookies.go
								
								
								
								
							|  | @ -1,10 +1,14 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"crypto/aes" | ||||||
|  | 	"crypto/cipher" | ||||||
| 	"crypto/hmac" | 	"crypto/hmac" | ||||||
|  | 	"crypto/rand" | ||||||
| 	"crypto/sha1" | 	"crypto/sha1" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | @ -59,3 +63,37 @@ func checkHmac(input, expected string) bool { | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) { | ||||||
|  | 	ciphertext := make([]byte, aes.BlockSize+len(access_token)) | ||||||
|  | 	iv := ciphertext[:aes.BlockSize] | ||||||
|  | 	if _, err := io.ReadFull(rand.Reader, iv); err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to create access code initialization vector") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	stream := cipher.NewCFBEncrypter(aes_cipher, iv) | ||||||
|  | 	stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token)) | ||||||
|  | 	return base64.StdEncoding.EncodeToString(ciphertext), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) { | ||||||
|  | 	encrypted_access_token, err := base64.StdEncoding.DecodeString( | ||||||
|  | 		encoded_access_token) | ||||||
|  | 
 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("failed to decode access token") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(encrypted_access_token) < aes.BlockSize { | ||||||
|  | 		return "", fmt.Errorf("encrypted access token should be "+ | ||||||
|  | 			"at least %d bytes, but is only %d bytes", | ||||||
|  | 			aes.BlockSize, len(encrypted_access_token)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	iv := encrypted_access_token[:aes.BlockSize] | ||||||
|  | 	encrypted_access_token = encrypted_access_token[aes.BlockSize:] | ||||||
|  | 	stream := cipher.NewCFBDecrypter(aes_cipher, iv) | ||||||
|  | 	stream.XORKeyStream(encrypted_access_token, encrypted_access_token) | ||||||
|  | 
 | ||||||
|  | 	return string(encrypted_access_token), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,23 @@ | ||||||
|  | package main | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/aes" | ||||||
|  | 	"github.com/bmizerany/assert" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||||
|  | 	const key = "0123456789abcdefghijklmnopqrstuv" | ||||||
|  | 	const access_token = "my access token" | ||||||
|  | 	c, err := aes.NewCipher([]byte(key)) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	encoded_token, err := encodeAccessToken(c, access_token) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	decoded_token, err := decodeAccessToken(c, encoded_token) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
|  | 
 | ||||||
|  | 	assert.NotEqual(t, access_token, encoded_token) | ||||||
|  | 	assert.Equal(t, access_token, decoded_token) | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								main.go
								
								
								
								
							
							
						
						
									
										1
									
								
								main.go
								
								
								
								
							|  | @ -30,6 +30,7 @@ func main() { | ||||||
| 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | ||||||
| 	flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path") | 	flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path") | ||||||
| 	flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") | 	flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") | ||||||
|  | 	flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") | ||||||
| 	flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") | 	flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") | ||||||
| 	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") | 	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,6 +2,8 @@ package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"crypto/aes" | ||||||
|  | 	"crypto/cipher" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | @ -45,6 +47,8 @@ type OauthProxy struct { | ||||||
| 	DisplayHtpasswdForm bool | 	DisplayHtpasswdForm bool | ||||||
| 	serveMux            http.Handler | 	serveMux            http.Handler | ||||||
| 	PassBasicAuth       bool | 	PassBasicAuth       bool | ||||||
|  | 	PassAccessToken     bool | ||||||
|  | 	AesCipher           cipher.Block | ||||||
| 	skipAuthRegex       []string | 	skipAuthRegex       []string | ||||||
| 	compiledRegex       []*regexp.Regexp | 	compiledRegex       []*regexp.Regexp | ||||||
| 	templates           *template.Template | 	templates           *template.Template | ||||||
|  | @ -116,6 +120,29 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | ||||||
| 
 | 
 | ||||||
| 	log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) | 	log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) | ||||||
| 
 | 
 | ||||||
|  | 	var aes_cipher cipher.Block | ||||||
|  | 
 | ||||||
|  | 	if opts.PassAccessToken == true { | ||||||
|  | 		valid_cookie_secret_size := false | ||||||
|  | 		for _, i := range []int{16, 24, 32} { | ||||||
|  | 			if len(opts.CookieSecret) == i { | ||||||
|  | 				valid_cookie_secret_size = true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if valid_cookie_secret_size == false { | ||||||
|  | 			log.Fatal("cookie_secret must be 16, 24, or 32 bytes " + | ||||||
|  | 				"to create an AES cipher when " + | ||||||
|  | 				"pass_access_token == true") | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		var err error | ||||||
|  | 		aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Fatal("error creating AES cipher with "+ | ||||||
|  | 				"pass_access_token == true: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return &OauthProxy{ | 	return &OauthProxy{ | ||||||
| 		CookieKey:      "_oauthproxy", | 		CookieKey:      "_oauthproxy", | ||||||
| 		CookieSeed:     opts.CookieSecret, | 		CookieSeed:     opts.CookieSecret, | ||||||
|  | @ -136,6 +163,8 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { | ||||||
| 		skipAuthRegex:      opts.SkipAuthRegex, | 		skipAuthRegex:      opts.SkipAuthRegex, | ||||||
| 		compiledRegex:      opts.CompiledRegex, | 		compiledRegex:      opts.CompiledRegex, | ||||||
| 		PassBasicAuth:      opts.PassBasicAuth, | 		PassBasicAuth:      opts.PassBasicAuth, | ||||||
|  | 		PassAccessToken:    opts.PassAccessToken, | ||||||
|  | 		AesCipher:          aes_cipher, | ||||||
| 		templates:          loadTemplates(opts.CustomTemplatesDir), | 		templates:          loadTemplates(opts.CustomTemplatesDir), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | @ -337,6 +366,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	var ok bool | 	var ok bool | ||||||
| 	var user string | 	var user string | ||||||
| 	var email string | 	var email string | ||||||
|  | 	var access_token string | ||||||
| 
 | 
 | ||||||
| 	if req.URL.Path == pingPath { | 	if req.URL.Path == pingPath { | ||||||
| 		p.PingPage(rw) | 		p.PingPage(rw) | ||||||
|  | @ -390,7 +420,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		_, email, err := p.redeemCode(req.Host, req.Form.Get("code")) | 		access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code")) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Printf("%s error redeeming code %s", remoteAddr, err) | 			log.Printf("%s error redeeming code %s", remoteAddr, err) | ||||||
| 			p.ErrorPage(rw, 500, "Internal Error", err.Error()) | 			p.ErrorPage(rw, 500, "Internal Error", err.Error()) | ||||||
|  | @ -405,7 +435,20 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		// set cookie, or deny
 | 		// set cookie, or deny
 | ||||||
| 		if p.Validator(email) { | 		if p.Validator(email) { | ||||||
| 			log.Printf("%s authenticating %s completed", remoteAddr, email) | 			log.Printf("%s authenticating %s completed", remoteAddr, email) | ||||||
|  | 			encoded_token := "" | ||||||
|  | 			if p.PassAccessToken { | ||||||
|  | 				encoded_token, err = encodeAccessToken(p.AesCipher, access_token) | ||||||
|  | 				if err != nil { | ||||||
|  | 					log.Printf("error encoding access token: %s", err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			access_token = "" | ||||||
|  | 
 | ||||||
|  | 			if encoded_token != "" { | ||||||
|  | 				p.SetCookie(rw, req, email+"|"+encoded_token) | ||||||
|  | 			} else { | ||||||
| 				p.SetCookie(rw, req, email) | 				p.SetCookie(rw, req, email) | ||||||
|  | 			} | ||||||
| 			http.Redirect(rw, req, redirect, 302) | 			http.Redirect(rw, req, redirect, 302) | ||||||
| 			return | 			return | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -417,7 +460,16 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		cookie, err := req.Cookie(p.CookieKey) | 		cookie, err := req.Cookie(p.CookieKey) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			email, ok = validateCookie(cookie, p.CookieSeed) | 			var value string | ||||||
|  | 			value, ok = validateCookie(cookie, p.CookieSeed) | ||||||
|  | 			components := strings.Split(value, "|") | ||||||
|  | 			email = components[0] | ||||||
|  | 			if len(components) == 2 { | ||||||
|  | 				access_token, err = decodeAccessToken(p.AesCipher, components[1]) | ||||||
|  | 				if err != nil { | ||||||
|  | 					log.Printf("error decoding access token: %s", err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 			user = strings.Split(email, "@")[0] | 			user = strings.Split(email, "@")[0] | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | @ -437,6 +489,9 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		req.Header["X-Forwarded-User"] = []string{user} | 		req.Header["X-Forwarded-User"] = []string{user} | ||||||
| 		req.Header["X-Forwarded-Email"] = []string{email} | 		req.Header["X-Forwarded-Email"] = []string{email} | ||||||
| 	} | 	} | ||||||
|  | 	if access_token != "" { | ||||||
|  | 		req.Header["X-Forwarded-Access-Token"] = []string{access_token} | ||||||
|  | 	} | ||||||
| 	if email == "" { | 	if email == "" { | ||||||
| 		rw.Header().Set("GAP-Auth", user) | 		rw.Header().Set("GAP-Auth", user) | ||||||
| 	} else { | 	} else { | ||||||
|  |  | ||||||
|  | @ -1,12 +1,17 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"github.com/bitly/go-simplejson" | ||||||
|  | 	"github.com/bitly/google_auth_proxy/providers" | ||||||
|  | 	"github.com/bmizerany/assert" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestNewReverseProxy(t *testing.T) { | func TestNewReverseProxy(t *testing.T) { | ||||||
|  | @ -18,8 +23,7 @@ func TestNewReverseProxy(t *testing.T) { | ||||||
| 	defer backend.Close() | 	defer backend.Close() | ||||||
| 
 | 
 | ||||||
| 	backendURL, _ := url.Parse(backend.URL) | 	backendURL, _ := url.Parse(backend.URL) | ||||||
| 	backendHostname := "upstream.127.0.0.1.xip.io" | 	backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) | ||||||
| 	_, backendPort, _ := net.SplitHostPort(backendURL.Host) |  | ||||||
| 	backendHost := net.JoinHostPort(backendHostname, backendPort) | 	backendHost := net.JoinHostPort(backendHostname, backendPort) | ||||||
| 	proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") | 	proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") | ||||||
| 
 | 
 | ||||||
|  | @ -61,3 +65,175 @@ func TestEncodedSlashes(t *testing.T) { | ||||||
| 		t.Errorf("got bad request %q expected %q", seen, encodedPath) | 		t.Errorf("got bad request %q expected %q", seen, encodedPath) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type TestProvider struct { | ||||||
|  | 	*providers.ProviderData | ||||||
|  | 	EmailAddress string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, | ||||||
|  | 	unused_access_token string) (string, error) { | ||||||
|  | 	return tp.EmailAddress, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PassAccessTokenTest struct { | ||||||
|  | 	provider_server *httptest.Server | ||||||
|  | 	proxy           *OauthProxy | ||||||
|  | 	opts            *Options | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PassAccessTokenTestOptions struct { | ||||||
|  | 	PassAccessToken bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { | ||||||
|  | 	t := &PassAccessTokenTest{} | ||||||
|  | 
 | ||||||
|  | 	t.provider_server = httptest.NewServer( | ||||||
|  | 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 			url := r.URL | ||||||
|  | 			payload := "" | ||||||
|  | 			switch url.Path { | ||||||
|  | 			case "/oauth/token": | ||||||
|  | 				payload = `{"access_token": "my_auth_token"}` | ||||||
|  | 			default: | ||||||
|  | 				token_header := r.Header["X-Forwarded-Access-Token"] | ||||||
|  | 				if len(token_header) != 0 { | ||||||
|  | 					payload = token_header[0] | ||||||
|  | 				} else { | ||||||
|  | 					payload = "No access token found." | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			w.WriteHeader(200) | ||||||
|  | 			w.Write([]byte(payload)) | ||||||
|  | 		})) | ||||||
|  | 
 | ||||||
|  | 	t.opts = NewOptions() | ||||||
|  | 	t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) | ||||||
|  | 	// The CookieSecret must be 32 bytes in order to create the AES
 | ||||||
|  | 	// cipher.
 | ||||||
|  | 	t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" | ||||||
|  | 	t.opts.ClientID = "bazquux" | ||||||
|  | 	t.opts.ClientSecret = "foobar" | ||||||
|  | 	t.opts.CookieSecure = false | ||||||
|  | 	t.opts.PassAccessToken = opts.PassAccessToken | ||||||
|  | 	t.opts.Validate() | ||||||
|  | 
 | ||||||
|  | 	provider_url, _ := url.Parse(t.provider_server.URL) | ||||||
|  | 	const email_address = "michael.bland@gsa.gov" | ||||||
|  | 
 | ||||||
|  | 	t.opts.provider = &TestProvider{ | ||||||
|  | 		ProviderData: &providers.ProviderData{ | ||||||
|  | 			ProviderName: "Test Provider", | ||||||
|  | 			LoginUrl: &url.URL{ | ||||||
|  | 				Scheme: "http", | ||||||
|  | 				Host:   provider_url.Host, | ||||||
|  | 				Path:   "/oauth/authorize", | ||||||
|  | 			}, | ||||||
|  | 			RedeemUrl: &url.URL{ | ||||||
|  | 				Scheme: "http", | ||||||
|  | 				Host:   provider_url.Host, | ||||||
|  | 				Path:   "/oauth/token", | ||||||
|  | 			}, | ||||||
|  | 			ProfileUrl: &url.URL{ | ||||||
|  | 				Scheme: "http", | ||||||
|  | 				Host:   provider_url.Host, | ||||||
|  | 				Path:   "/api/v1/profile", | ||||||
|  | 			}, | ||||||
|  | 			Scope: "profile.email", | ||||||
|  | 		}, | ||||||
|  | 		EmailAddress: email_address, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	t.proxy = NewOauthProxy(t.opts, func(email string) bool { | ||||||
|  | 		return email == email_address | ||||||
|  | 	}) | ||||||
|  | 	return t | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func Close(t *PassAccessTokenTest) { | ||||||
|  | 	t.provider_server.Close() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) { | ||||||
|  | 	rw := httptest.NewRecorder() | ||||||
|  | 	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code", | ||||||
|  | 		strings.NewReader("")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, "" | ||||||
|  | 	} | ||||||
|  | 	pac_test.proxy.ServeHTTP(rw, req) | ||||||
|  | 	return rw.Code, rw.HeaderMap["Set-Cookie"][0] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int, | ||||||
|  | 	access_token string) { | ||||||
|  | 	cookie_key := pac_test.proxy.CookieKey | ||||||
|  | 	var value string | ||||||
|  | 	key_prefix := cookie_key + "=" | ||||||
|  | 
 | ||||||
|  | 	for _, field := range strings.Split(cookie, "; ") { | ||||||
|  | 		value = strings.TrimPrefix(field, key_prefix) | ||||||
|  | 		if value != field { | ||||||
|  | 			break | ||||||
|  | 		} else { | ||||||
|  | 			value = "" | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if value == "" { | ||||||
|  | 		return 0, "" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req, err := http.NewRequest("GET", "/", strings.NewReader("")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, "" | ||||||
|  | 	} | ||||||
|  | 	req.AddCookie(&http.Cookie{ | ||||||
|  | 		Name:     cookie_key, | ||||||
|  | 		Value:    value, | ||||||
|  | 		Path:     "/", | ||||||
|  | 		Expires:  time.Now().Add(time.Duration(24)), | ||||||
|  | 		HttpOnly: true, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	rw := httptest.NewRecorder() | ||||||
|  | 	pac_test.proxy.ServeHTTP(rw, req) | ||||||
|  | 	return rw.Code, rw.Body.String() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestForwardAccessTokenUpstream(t *testing.T) { | ||||||
|  | 	pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
|  | 		PassAccessToken: true, | ||||||
|  | 	}) | ||||||
|  | 	defer Close(pac_test) | ||||||
|  | 
 | ||||||
|  | 	// A successful validation will redirect and set the auth cookie.
 | ||||||
|  | 	code, cookie := getCallbackEndpoint(pac_test) | ||||||
|  | 	assert.Equal(t, 302, code) | ||||||
|  | 	assert.NotEqual(t, nil, cookie) | ||||||
|  | 
 | ||||||
|  | 	// Now we make a regular request; the access_token from the cookie is
 | ||||||
|  | 	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | ||||||
|  | 	// read by the test provider server and written in the response body.
 | ||||||
|  | 	code, payload := getRootEndpoint(pac_test, cookie) | ||||||
|  | 	assert.Equal(t, 200, code) | ||||||
|  | 	assert.Equal(t, "my_auth_token", payload) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | ||||||
|  | 	pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
|  | 		PassAccessToken: false, | ||||||
|  | 	}) | ||||||
|  | 	defer Close(pac_test) | ||||||
|  | 
 | ||||||
|  | 	// A successful validation will redirect and set the auth cookie.
 | ||||||
|  | 	code, cookie := getCallbackEndpoint(pac_test) | ||||||
|  | 	assert.Equal(t, 302, code) | ||||||
|  | 	assert.NotEqual(t, nil, cookie) | ||||||
|  | 
 | ||||||
|  | 	// Now we make a regular request, but the access token header should
 | ||||||
|  | 	// not be present.
 | ||||||
|  | 	code, payload := getRootEndpoint(pac_test, cookie) | ||||||
|  | 	assert.Equal(t, 200, code) | ||||||
|  | 	assert.Equal(t, "No access token found.", payload) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -33,6 +33,7 @@ type Options struct { | ||||||
| 	Upstreams       []string `flag:"upstream" cfg:"upstreams"` | 	Upstreams       []string `flag:"upstream" cfg:"upstreams"` | ||||||
| 	SkipAuthRegex   []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | 	SkipAuthRegex   []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | ||||||
| 	PassBasicAuth   bool     `flag:"pass-basic-auth" cfg:"pass_basic_auth"` | 	PassBasicAuth   bool     `flag:"pass-basic-auth" cfg:"pass_basic_auth"` | ||||||
|  | 	PassAccessToken bool     `flag:"pass-access-token" cfg:"pass_access_token"` | ||||||
| 	PassHostHeader  bool     `flag:"pass-host-header" cfg:"pass_host_header"` | 	PassHostHeader  bool     `flag:"pass-host-header" cfg:"pass_host_header"` | ||||||
| 
 | 
 | ||||||
| 	// These options allow for other providers besides Google, with
 | 	// These options allow for other providers besides Google, with
 | ||||||
|  | @ -61,6 +62,7 @@ func NewOptions() *Options { | ||||||
| 		CookieHttpOnly:      true, | 		CookieHttpOnly:      true, | ||||||
| 		CookieExpire:        time.Duration(168) * time.Hour, | 		CookieExpire:        time.Duration(168) * time.Hour, | ||||||
| 		PassBasicAuth:       true, | 		PassBasicAuth:       true, | ||||||
|  | 		PassAccessToken:     false, | ||||||
| 		PassHostHeader:      true, | 		PassHostHeader:      true, | ||||||
| 		RequestLogging:      true, | 		RequestLogging:      true, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue