Merge pull request #360 from jehiah/csrf_validation_360
CSRF protection for OAuth flow.
This commit is contained in:
		
						commit
						4464655276
					
				| 
						 | 
					@ -0,0 +1,16 @@
 | 
				
			||||||
 | 
					package cookie
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Nonce() (nonce string, err error) {
 | 
				
			||||||
 | 
						b := make([]byte, 16)
 | 
				
			||||||
 | 
						_, err = rand.Read(b)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						nonce = fmt.Sprintf("%x", b)
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -37,6 +37,7 @@ var SignatureHeaders []string = []string{
 | 
				
			||||||
type OAuthProxy struct {
 | 
					type OAuthProxy struct {
 | 
				
			||||||
	CookieSeed     string
 | 
						CookieSeed     string
 | 
				
			||||||
	CookieName     string
 | 
						CookieName     string
 | 
				
			||||||
 | 
						CSRFCookieName string
 | 
				
			||||||
	CookieDomain   string
 | 
						CookieDomain   string
 | 
				
			||||||
	CookieSecure   bool
 | 
						CookieSecure   bool
 | 
				
			||||||
	CookieHttpOnly bool
 | 
						CookieHttpOnly bool
 | 
				
			||||||
| 
						 | 
					@ -174,6 +175,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &OAuthProxy{
 | 
						return &OAuthProxy{
 | 
				
			||||||
		CookieName:     opts.CookieName,
 | 
							CookieName:     opts.CookieName,
 | 
				
			||||||
 | 
							CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
 | 
				
			||||||
		CookieSeed:     opts.CookieSecret,
 | 
							CookieSeed:     opts.CookieSecret,
 | 
				
			||||||
		CookieDomain:   opts.CookieDomain,
 | 
							CookieDomain:   opts.CookieDomain,
 | 
				
			||||||
		CookieSecure:   opts.CookieSecure,
 | 
							CookieSecure:   opts.CookieSecure,
 | 
				
			||||||
| 
						 | 
					@ -245,7 +247,22 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
 | 
					func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
 | 
				
			||||||
 | 
						if value != "" {
 | 
				
			||||||
 | 
							value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
 | 
				
			||||||
 | 
							if len(value) > 4096 {
 | 
				
			||||||
 | 
								// Cookies cannot be larger than 4kb
 | 
				
			||||||
 | 
								log.Printf("WARNING - Cookie Size: %d bytes", len(value))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return p.makeCookie(req, p.CookieName, value, expiration, now)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
 | 
				
			||||||
 | 
						return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
 | 
				
			||||||
	domain := req.Host
 | 
						domain := req.Host
 | 
				
			||||||
	if h, _, err := net.SplitHostPort(domain); err == nil {
 | 
						if h, _, err := net.SplitHostPort(domain); err == nil {
 | 
				
			||||||
		domain = h
 | 
							domain = h
 | 
				
			||||||
| 
						 | 
					@ -257,15 +274,8 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
 | 
				
			||||||
		domain = p.CookieDomain
 | 
							domain = p.CookieDomain
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if value != "" {
 | 
					 | 
				
			||||||
		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
 | 
					 | 
				
			||||||
		if len(value) > 4096 {
 | 
					 | 
				
			||||||
			// Cookies cannot be larger than 4kb
 | 
					 | 
				
			||||||
			log.Printf("WARNING - Cookie Size: %d bytes", len(value))
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &http.Cookie{
 | 
						return &http.Cookie{
 | 
				
			||||||
		Name:     p.CookieName,
 | 
							Name:     name,
 | 
				
			||||||
		Value:    value,
 | 
							Value:    value,
 | 
				
			||||||
		Path:     "/",
 | 
							Path:     "/",
 | 
				
			||||||
		Domain:   domain,
 | 
							Domain:   domain,
 | 
				
			||||||
| 
						 | 
					@ -275,12 +285,20 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
 | 
					func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
	http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now()))
 | 
						http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) {
 | 
					func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
 | 
				
			||||||
	http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
 | 
						http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
						http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
 | 
				
			||||||
 | 
						http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
 | 
					func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
 | 
				
			||||||
| 
						 | 
					@ -309,7 +327,7 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	p.SetCookie(rw, req, value)
 | 
						p.SetSessionCookie(rw, req, value)
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -339,7 +357,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
 | 
					func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
 | 
				
			||||||
	p.ClearCookie(rw, req)
 | 
						p.ClearSessionCookie(rw, req)
 | 
				
			||||||
	rw.WriteHeader(code)
 | 
						rw.WriteHeader(code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	redirect_url := req.URL.RequestURI()
 | 
						redirect_url := req.URL.RequestURI()
 | 
				
			||||||
| 
						 | 
					@ -384,20 +402,18 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
 | 
				
			||||||
	return "", false
 | 
						return "", false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) {
 | 
					func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
 | 
				
			||||||
	err := req.ParseForm()
 | 
						err = req.ParseForm()
 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	redirect := req.FormValue("rd")
 | 
						redirect = req.Form.Get("rd")
 | 
				
			||||||
 | 
						if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
 | 
				
			||||||
	if redirect == "" {
 | 
					 | 
				
			||||||
		redirect = "/"
 | 
							redirect = "/"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return redirect, err
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
 | 
					func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
 | 
				
			||||||
| 
						 | 
					@ -459,18 +475,24 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
 | 
					func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
	p.ClearCookie(rw, req)
 | 
						p.ClearSessionCookie(rw, req)
 | 
				
			||||||
	http.Redirect(rw, req, "/", 302)
 | 
						http.Redirect(rw, req, "/", 302)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
 | 
					func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
 | 
						nonce, err := cookie.Nonce()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						p.SetCSRFCookie(rw, req, nonce)
 | 
				
			||||||
	redirect, err := p.GetRedirect(req)
 | 
						redirect, err := p.GetRedirect(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
							p.ErrorPage(rw, 500, "Internal Error", err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	redirectURI := p.GetRedirectURI(req.Host)
 | 
						redirectURI := p.GetRedirectURI(req.Host)
 | 
				
			||||||
	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
 | 
						http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
 | 
					func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
| 
						 | 
					@ -495,7 +517,25 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	redirect := req.Form.Get("state")
 | 
						s := strings.SplitN(req.Form.Get("state"), ":", 2)
 | 
				
			||||||
 | 
						if len(s) != 2 {
 | 
				
			||||||
 | 
							p.ErrorPage(rw, 500, "Internal Error", "Invalid State")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						nonce := s[0]
 | 
				
			||||||
 | 
						redirect := s[1]
 | 
				
			||||||
 | 
						c, err := req.Cookie(p.CSRFCookieName)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							p.ErrorPage(rw, 403, "Permission Denied", err.Error())
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						p.ClearCSRFCookie(rw, req)
 | 
				
			||||||
 | 
						if c.Value != nonce {
 | 
				
			||||||
 | 
							log.Printf("%s csrf token mismatch, potential attack", remoteAddr)
 | 
				
			||||||
 | 
							p.ErrorPage(rw, 403, "Permission Denied", "csrf failed")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
 | 
						if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
 | 
				
			||||||
		redirect = "/"
 | 
							redirect = "/"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -595,7 +635,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if clearSession {
 | 
						if clearSession {
 | 
				
			||||||
		p.ClearCookie(rw, req)
 | 
							p.ClearSessionCookie(rw, req)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if session == nil {
 | 
						if session == nil {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,10 +170,14 @@ func TestBasicAuthPassword(t *testing.T) {
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rw := httptest.NewRecorder()
 | 
						rw := httptest.NewRecorder()
 | 
				
			||||||
	req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
 | 
						req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
 | 
				
			||||||
		strings.NewReader(""))
 | 
							strings.NewReader(""))
 | 
				
			||||||
 | 
						req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
 | 
				
			||||||
	proxy.ServeHTTP(rw, req)
 | 
						proxy.ServeHTTP(rw, req)
 | 
				
			||||||
	cookie := rw.HeaderMap["Set-Cookie"][0]
 | 
						if rw.Code >= 400 {
 | 
				
			||||||
 | 
							t.Fatalf("expected 3xx got %d", rw.Code)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						cookie := rw.HeaderMap["Set-Cookie"][1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	cookieName := proxy.CookieName
 | 
						cookieName := proxy.CookieName
 | 
				
			||||||
	var value string
 | 
						var value string
 | 
				
			||||||
| 
						 | 
					@ -196,9 +200,11 @@ func TestBasicAuthPassword(t *testing.T) {
 | 
				
			||||||
		Expires:  time.Now().Add(time.Duration(24)),
 | 
							Expires:  time.Now().Add(time.Duration(24)),
 | 
				
			||||||
		HttpOnly: true,
 | 
							HttpOnly: true,
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rw = httptest.NewRecorder()
 | 
						rw = httptest.NewRecorder()
 | 
				
			||||||
	proxy.ServeHTTP(rw, req)
 | 
						proxy.ServeHTTP(rw, req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
 | 
						expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
 | 
				
			||||||
	assert.Equal(t, expectedHeader, rw.Body.String())
 | 
						assert.Equal(t, expectedHeader, rw.Body.String())
 | 
				
			||||||
	provider_server.Close()
 | 
						provider_server.Close()
 | 
				
			||||||
| 
						 | 
					@ -263,13 +269,14 @@ func (pat_test *PassAccessTokenTest) Close() {
 | 
				
			||||||
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
 | 
					func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
 | 
				
			||||||
	cookie string) {
 | 
						cookie string) {
 | 
				
			||||||
	rw := httptest.NewRecorder()
 | 
						rw := httptest.NewRecorder()
 | 
				
			||||||
	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
 | 
						req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
 | 
				
			||||||
		strings.NewReader(""))
 | 
							strings.NewReader(""))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return 0, ""
 | 
							return 0, ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
 | 
				
			||||||
	pat_test.proxy.ServeHTTP(rw, req)
 | 
						pat_test.proxy.ServeHTTP(rw, req)
 | 
				
			||||||
	return rw.Code, rw.HeaderMap["Set-Cookie"][0]
 | 
						return rw.Code, rw.HeaderMap["Set-Cookie"][1]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
 | 
					func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
 | 
				
			||||||
| 
						 | 
					@ -314,14 +321,18 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// A successful validation will redirect and set the auth cookie.
 | 
						// A successful validation will redirect and set the auth cookie.
 | 
				
			||||||
	code, cookie := pat_test.getCallbackEndpoint()
 | 
						code, cookie := pat_test.getCallbackEndpoint()
 | 
				
			||||||
	assert.Equal(t, 302, code)
 | 
						if code != 302 {
 | 
				
			||||||
 | 
							t.Fatalf("expected 302; got %d", code)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	assert.NotEqual(t, nil, cookie)
 | 
						assert.NotEqual(t, nil, cookie)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Now we make a regular request; the access_token from the cookie is
 | 
						// Now we make a regular request; the access_token from the cookie is
 | 
				
			||||||
	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | 
						// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | 
				
			||||||
	// read by the test provider server and written in the response body.
 | 
						// read by the test provider server and written in the response body.
 | 
				
			||||||
	code, payload := pat_test.getRootEndpoint(cookie)
 | 
						code, payload := pat_test.getRootEndpoint(cookie)
 | 
				
			||||||
	assert.Equal(t, 200, code)
 | 
						if code != 200 {
 | 
				
			||||||
 | 
							t.Fatalf("expected 200; got %d", code)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	assert.Equal(t, "my_auth_token", payload)
 | 
						assert.Equal(t, "my_auth_token", payload)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -333,13 +344,17 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// A successful validation will redirect and set the auth cookie.
 | 
						// A successful validation will redirect and set the auth cookie.
 | 
				
			||||||
	code, cookie := pat_test.getCallbackEndpoint()
 | 
						code, cookie := pat_test.getCallbackEndpoint()
 | 
				
			||||||
	assert.Equal(t, 302, code)
 | 
						if code != 302 {
 | 
				
			||||||
 | 
							t.Fatalf("expected 302; got %d", code)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	assert.NotEqual(t, nil, cookie)
 | 
						assert.NotEqual(t, nil, cookie)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Now we make a regular request, but the access token header should
 | 
						// Now we make a regular request, but the access token header should
 | 
				
			||||||
	// not be present.
 | 
						// not be present.
 | 
				
			||||||
	code, payload := pat_test.getRootEndpoint(cookie)
 | 
						code, payload := pat_test.getRootEndpoint(cookie)
 | 
				
			||||||
	assert.Equal(t, 200, code)
 | 
						if code != 200 {
 | 
				
			||||||
 | 
							t.Fatalf("expected 200; got %d", code)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	assert.Equal(t, "No access token found.", payload)
 | 
						assert.Equal(t, "No access token found.", payload)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -457,7 +472,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
 | 
					func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
 | 
				
			||||||
	return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
 | 
						return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
 | 
					func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
 | 
				
			||||||
| 
						 | 
					@ -465,7 +480,7 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
 | 
						p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -697,7 +712,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		panic(err)
 | 
							panic(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	cookie := proxy.MakeCookie(req, value, proxy.CookieExpire, time.Now())
 | 
						cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
 | 
				
			||||||
	req.AddCookie(cookie)
 | 
						req.AddCookie(cookie)
 | 
				
			||||||
	// This is used by the upstream to validate the signature.
 | 
						// This is used by the upstream to validate the signature.
 | 
				
			||||||
	st.authenticator.auth = hmacauth.NewHmacAuth(
 | 
						st.authenticator.auth = hmacauth.NewHmacAuth(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,7 +8,6 @@ import (
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/bitly/oauth2_proxy/cookie"
 | 
						"github.com/bitly/oauth2_proxy/cookie"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -79,7 +78,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetLoginURL with typical oauth parameters
 | 
					// GetLoginURL with typical oauth parameters
 | 
				
			||||||
func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
 | 
					func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
 | 
				
			||||||
	var a url.URL
 | 
						var a url.URL
 | 
				
			||||||
	a = *p.LoginURL
 | 
						a = *p.LoginURL
 | 
				
			||||||
	params, _ := url.ParseQuery(a.RawQuery)
 | 
						params, _ := url.ParseQuery(a.RawQuery)
 | 
				
			||||||
| 
						 | 
					@ -88,9 +87,7 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
 | 
				
			||||||
	params.Add("scope", p.Scope)
 | 
						params.Add("scope", p.Scope)
 | 
				
			||||||
	params.Set("client_id", p.ClientID)
 | 
						params.Set("client_id", p.ClientID)
 | 
				
			||||||
	params.Set("response_type", "code")
 | 
						params.Set("response_type", "code")
 | 
				
			||||||
	if strings.HasPrefix(finalRedirect, "/") && !strings.HasPrefix(finalRedirect,"//") {
 | 
						params.Add("state", state)
 | 
				
			||||||
		params.Add("state", finalRedirect)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	a.RawQuery = params.Encode()
 | 
						a.RawQuery = params.Encode()
 | 
				
			||||||
	return a.String()
 | 
						return a.String()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue