2357 lines
		
	
	
		
			64 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			2357 lines
		
	
	
		
			64 KiB
		
	
	
	
		
			Go
		
	
	
	
| package main
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"context"
 | |
| 	"crypto"
 | |
| 	"encoding/base64"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"io/ioutil"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"regexp"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/coreos/go-oidc"
 | |
| 	"github.com/mbland/hmacauth"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
 | |
| 	sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/providers"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| 	"golang.org/x/net/websocket"
 | |
| )
 | |
| 
 | |
const (
	// rawCookieSecret is a 32-byte cookie secret in plain form.
	rawCookieSecret = "secretthirtytwobytes+abcdefghijk"
	// base64CookieSecret is rawCookieSecret encoded as unpadded base64.
	base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams"

	// clientID and clientSecret are fixed credential strings for the tests
	// (presumably consumed by baseTestOptions, which is defined elsewhere —
	// confirm).
	clientID     = "3984n253984d7348dm8234yf982t"
	clientSecret = "gv3498mfc9t23y23974dm2394dm9"
)
 | |
| 
 | |
| func init() {
 | |
| 	logger.SetFlags(logger.Lshortfile)
 | |
| }
 | |
| 
 | |
// WebSocketOrRestHandler is a test backend that dispatches each request to
// one of two handlers depending on whether it asks for a websocket upgrade.
type WebSocketOrRestHandler struct {
	// restHandler serves every request that is not a websocket upgrade.
	restHandler http.Handler
	// wsHandler serves requests whose Upgrade header names "websocket".
	wsHandler http.Handler
}

// ServeHTTP routes r to wsHandler when its Upgrade header is "websocket" and
// to restHandler otherwise.
//
// The comparison is ASCII case-insensitive, as RFC 6455 requires for the
// Upgrade token, so clients that send "WebSocket" are routed correctly too.
func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
		h.wsHandler.ServeHTTP(w, r)
		return
	}
	h.restHandler.ServeHTTP(w, r)
}
 | |
| 
 | |
| func TestWebSocketProxy(t *testing.T) {
 | |
| 	handler := WebSocketOrRestHandler{
 | |
| 		restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 			w.WriteHeader(200)
 | |
| 			hostname, _, _ := net.SplitHostPort(r.Host)
 | |
| 			_, err := w.Write([]byte(hostname))
 | |
| 			if err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			}
 | |
| 		}),
 | |
| 		wsHandler: websocket.Handler(func(ws *websocket.Conn) {
 | |
| 			defer func(t *testing.T) {
 | |
| 				if err := ws.Close(); err != nil {
 | |
| 					t.Fatal(err)
 | |
| 				}
 | |
| 			}(t)
 | |
| 			var data []byte
 | |
| 			err := websocket.Message.Receive(ws, &data)
 | |
| 			if err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			}
 | |
| 			err = websocket.Message.Send(ws, data)
 | |
| 			if err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			}
 | |
| 		}),
 | |
| 	}
 | |
| 	backend := httptest.NewServer(&handler)
 | |
| 	t.Cleanup(backend.Close)
 | |
| 
 | |
| 	backendURL, _ := url.Parse(backend.URL)
 | |
| 
 | |
| 	opts := baseTestOptions()
 | |
| 	var auth hmacauth.HmacAuth
 | |
| 	opts.PassHostHeader = true
 | |
| 	proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth)
 | |
| 	frontend := httptest.NewServer(proxyHandler)
 | |
| 	t.Cleanup(frontend.Close)
 | |
| 
 | |
| 	frontendURL, _ := url.Parse(frontend.URL)
 | |
| 	frontendWSURL := "ws://" + frontendURL.Host + "/"
 | |
| 
 | |
| 	ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	request := []byte("hello, world!")
 | |
| 	err = websocket.Message.Send(ws, request)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	var response = make([]byte, 1024)
 | |
| 	err = websocket.Message.Receive(ws, &response)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	if g, e := string(request), string(response); g != e {
 | |
| 		t.Errorf("got body %q; expected %q", g, e)
 | |
| 	}
 | |
| 
 | |
| 	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
 | |
| 	res, _ := http.DefaultClient.Do(getReq)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(res.Body)
 | |
| 	backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
 | |
| 	if g, e := string(bodyBytes), backendHostname; g != e {
 | |
| 		t.Errorf("got body %q; expected %q", g, e)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestNewReverseProxy(t *testing.T) {
 | |
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		w.WriteHeader(200)
 | |
| 		hostname, _, _ := net.SplitHostPort(r.Host)
 | |
| 		_, err := w.Write([]byte(hostname))
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}))
 | |
| 	t.Cleanup(backend.Close)
 | |
| 
 | |
| 	backendURL, _ := url.Parse(backend.URL)
 | |
| 	backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
 | |
| 	backendHost := net.JoinHostPort(backendHostname, backendPort)
 | |
| 	proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
 | |
| 
 | |
| 	proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second})
 | |
| 	setProxyUpstreamHostHeader(proxyHandler, proxyURL)
 | |
| 	frontend := httptest.NewServer(proxyHandler)
 | |
| 	t.Cleanup(frontend.Close)
 | |
| 
 | |
| 	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
 | |
| 	res, _ := http.DefaultClient.Do(getReq)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(res.Body)
 | |
| 	if g, e := string(bodyBytes), backendHostname; g != e {
 | |
| 		t.Errorf("got body %q; expected %q", g, e)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestEncodedSlashes(t *testing.T) {
 | |
| 	var seen string
 | |
| 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		w.WriteHeader(200)
 | |
| 		seen = r.RequestURI
 | |
| 	}))
 | |
| 	t.Cleanup(backend.Close)
 | |
| 
 | |
| 	b, _ := url.Parse(backend.URL)
 | |
| 	proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second})
 | |
| 	setProxyDirector(proxyHandler)
 | |
| 	frontend := httptest.NewServer(proxyHandler)
 | |
| 	t.Cleanup(frontend.Close)
 | |
| 
 | |
| 	f, _ := url.Parse(frontend.URL)
 | |
| 	encodedPath := "/a%2Fb/?c=1"
 | |
| 	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
 | |
| 	_, err := http.DefaultClient.Do(getReq)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	if seen != encodedPath {
 | |
| 		t.Errorf("got bad request %q expected %q", seen, encodedPath)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestRobotsTxt(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	req, _ := http.NewRequest("GET", "/robots.txt", nil)
 | |
| 	proxy.ServeHTTP(rw, req)
 | |
| 	assert.Equal(t, 200, rw.Code)
 | |
| 	assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
 | |
| }
 | |
| 
 | |
| func TestIsValidRedirect(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	// Should match domains that are exactly foo.bar and any subdomain of bar.foo
 | |
| 	opts.WhitelistDomains = []string{
 | |
| 		"foo.bar",
 | |
| 		".bar.foo",
 | |
| 		"port.bar:8080",
 | |
| 		".sub.port.bar:8080",
 | |
| 		"anyport.bar:*",
 | |
| 		".sub.anyport.bar:*",
 | |
| 	}
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	testCases := []struct {
 | |
| 		Desc, Redirect string
 | |
| 		ExpectedResult bool
 | |
| 	}{
 | |
| 		{
 | |
| 			Desc:           "noRD",
 | |
| 			Redirect:       "",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "singleSlash",
 | |
| 			Redirect:       "/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "doubleSlash",
 | |
| 			Redirect:       "//redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validHTTP",
 | |
| 			Redirect:       "http://foo.bar/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validHTTPS",
 | |
| 			Redirect:       "https://foo.bar/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidHTTPSubdomain",
 | |
| 			Redirect:       "http://baz.foo.bar/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidHTTPSSubdomain",
 | |
| 			Redirect:       "https://baz.foo.bar/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validHTTPSubdomain",
 | |
| 			Redirect:       "http://baz.bar.foo/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validHTTPSSubdomain",
 | |
| 			Redirect:       "https://baz.bar.foo/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validHTTPDomain",
 | |
| 			Redirect:       "http://bar.foo/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidHTTP1",
 | |
| 			Redirect:       "http://foo.bar.evil.corp/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidHTTPS1",
 | |
| 			Redirect:       "https://foo.bar.evil.corp/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidHTTP2",
 | |
| 			Redirect:       "http://evil.corp/redirect?rd=foo.bar",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidHTTPS2",
 | |
| 			Redirect:       "https://evil.corp/redirect?rd=foo.bar",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidPort",
 | |
| 			Redirect:       "https://evil.corp:3838/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidEmptyPort",
 | |
| 			Redirect:       "http://foo.bar:3838/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidEmptyPortSubdomain",
 | |
| 			Redirect:       "http://baz.bar.foo:3838/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validSpecificPort",
 | |
| 			Redirect:       "http://port.bar:8080/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidSpecificPort",
 | |
| 			Redirect:       "http://port.bar:3838/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validSpecificPortSubdomain",
 | |
| 			Redirect:       "http://foo.sub.port.bar:8080/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "invalidSpecificPortSubdomain",
 | |
| 			Redirect:       "http://foo.sub.port.bar:3838/redirect",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validAnyPort1",
 | |
| 			Redirect:       "http://anyport.bar:8080/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validAnyPort2",
 | |
| 			Redirect:       "http://anyport.bar:8081/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validAnyPortSubdomain1",
 | |
| 			Redirect:       "http://a.sub.anyport.bar:8080/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "validAnyPortSubdomain2",
 | |
| 			Redirect:       "http://a.sub.anyport.bar:8081/redirect",
 | |
| 			ExpectedResult: true,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirect1",
 | |
| 			Redirect:       "/\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectSpace1",
 | |
| 			Redirect:       "/ /evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectSpace2",
 | |
| 			Redirect:       "/ \\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectTab1",
 | |
| 			Redirect:       "/\t/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectTab2",
 | |
| 			Redirect:       "/\t\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectVerticalTab1",
 | |
| 			Redirect:       "/\v/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectVerticalTab2",
 | |
| 			Redirect:       "/\v\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectNewLine1",
 | |
| 			Redirect:       "/\n/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectNewLine2",
 | |
| 			Redirect:       "/\n\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectCarriageReturn1",
 | |
| 			Redirect:       "/\r/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectCarriageReturn2",
 | |
| 			Redirect:       "/\r\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectTripleTab",
 | |
| 			Redirect:       "/\t\t/\t/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectTripleTab2",
 | |
| 			Redirect:       "/\t\t\\\t/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectQuadTab1",
 | |
| 			Redirect:       "/\t\t/\t\t\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectQuadTab2",
 | |
| 			Redirect:       "/\t\t\\\t\t/evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectPeriod1",
 | |
| 			Redirect:       "/./\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectPeriod2",
 | |
| 			Redirect:       "/./../../\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 		{
 | |
| 			Desc:           "openRedirectDoubleTab",
 | |
| 			Redirect:       "/\t/\t\\evil.com",
 | |
| 			ExpectedResult: false,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tc := range testCases {
 | |
| 		t.Run(tc.Desc, func(t *testing.T) {
 | |
| 			result := proxy.IsValidRedirect(tc.Redirect)
 | |
| 
 | |
| 			if result != tc.ExpectedResult {
 | |
| 				t.Errorf("expected %t got %t", tc.ExpectedResult, result)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestOpenRedirects(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	// Should match domains that are exactly foo.bar and any subdomain of bar.foo
 | |
| 	opts.WhitelistDomains = []string{
 | |
| 		"foo.bar",
 | |
| 		".bar.foo",
 | |
| 		"port.bar:8080",
 | |
| 		".sub.port.bar:8080",
 | |
| 		"anyport.bar:*",
 | |
| 		".sub.anyport.bar:*",
 | |
| 		"www.whitelisteddomain.tld",
 | |
| 	}
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	file, err := os.Open("./test/openredirects.txt")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	defer func(t *testing.T) {
 | |
| 		if err := file.Close(); err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}(t)
 | |
| 
 | |
| 	scanner := bufio.NewScanner(file)
 | |
| 	for scanner.Scan() {
 | |
| 		rd := scanner.Text()
 | |
| 		t.Run(rd, func(t *testing.T) {
 | |
| 			rdUnescaped, err := url.QueryUnescape(rd)
 | |
| 			if err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			}
 | |
| 			if proxy.IsValidRedirect(rdUnescaped) {
 | |
| 				t.Errorf("Expected %q to not be valid (unescaped: %q)", rd, rdUnescaped)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| 
 | |
| 	if err := scanner.Err(); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type TestProvider struct {
 | |
| 	*providers.ProviderData
 | |
| 	EmailAddress   string
 | |
| 	ValidToken     bool
 | |
| 	GroupValidator func(string) bool
 | |
| }
 | |
| 
 | |
| var _ providers.Provider = (*TestProvider)(nil)
 | |
| 
 | |
| func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
 | |
| 	return &TestProvider{
 | |
| 		ProviderData: &providers.ProviderData{
 | |
| 			ProviderName: "Test Provider",
 | |
| 			LoginURL: &url.URL{
 | |
| 				Scheme: "http",
 | |
| 				Host:   providerURL.Host,
 | |
| 				Path:   "/oauth/authorize",
 | |
| 			},
 | |
| 			RedeemURL: &url.URL{
 | |
| 				Scheme: "http",
 | |
| 				Host:   providerURL.Host,
 | |
| 				Path:   "/oauth/token",
 | |
| 			},
 | |
| 			ProfileURL: &url.URL{
 | |
| 				Scheme: "http",
 | |
| 				Host:   providerURL.Host,
 | |
| 				Path:   "/api/v1/profile",
 | |
| 			},
 | |
| 			Scope: "profile.email",
 | |
| 		},
 | |
| 		EmailAddress: emailAddress,
 | |
| 		GroupValidator: func(s string) bool {
 | |
| 			return true
 | |
| 		},
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (tp *TestProvider) GetEmailAddress(ctx context.Context, session *sessions.SessionState) (string, error) {
 | |
| 	return tp.EmailAddress, nil
 | |
| }
 | |
| 
 | |
| func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessions.SessionState) bool {
 | |
| 	return tp.ValidToken
 | |
| }
 | |
| 
 | |
| func TestBasicAuthPassword(t *testing.T) {
 | |
| 	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		logger.Printf("%#v", r)
 | |
| 		var payload string
 | |
| 		switch r.URL.Path {
 | |
| 		case "/oauth/token":
 | |
| 			payload = `{"access_token": "my_auth_token"}`
 | |
| 		default:
 | |
| 			payload = r.Header.Get("Authorization")
 | |
| 			if payload == "" {
 | |
| 				payload = "No Authorization header found."
 | |
| 			}
 | |
| 		}
 | |
| 		w.WriteHeader(200)
 | |
| 		_, err := w.Write([]byte(payload))
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}))
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.Upstreams = append(opts.Upstreams, providerServer.URL)
 | |
| 	opts.Cookie.Secure = false
 | |
| 	opts.PassBasicAuth = true
 | |
| 	opts.SetBasicAuth = true
 | |
| 	opts.PassUserHeaders = true
 | |
| 	opts.PreferEmailToUser = true
 | |
| 	opts.BasicAuthPassword = "This is a secure password"
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	providerURL, _ := url.Parse(providerServer.URL)
 | |
| 	const emailAddress = "john.doe@example.com"
 | |
| 
 | |
| 	opts.SetProvider(NewTestProvider(providerURL, emailAddress))
 | |
| 	proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | |
| 		return email == emailAddress
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader(""))
 | |
| 	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
 | |
| 	proxy.ServeHTTP(rw, req)
 | |
| 	if rw.Code >= 400 {
 | |
| 		t.Fatalf("expected 3xx got %d", rw.Code)
 | |
| 	}
 | |
| 	cookie := rw.Header().Values("Set-Cookie")[1]
 | |
| 
 | |
| 	cookieName := proxy.CookieName
 | |
| 	var value string
 | |
| 	keyPrefix := cookieName + "="
 | |
| 
 | |
| 	for _, field := range strings.Split(cookie, "; ") {
 | |
| 		value = strings.TrimPrefix(field, keyPrefix)
 | |
| 		if value != field {
 | |
| 			break
 | |
| 		} else {
 | |
| 			value = ""
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:     cookieName,
 | |
| 		Value:    value,
 | |
| 		Path:     "/",
 | |
| 		Expires:  time.Now().Add(time.Duration(24)),
 | |
| 		HttpOnly: true,
 | |
| 	})
 | |
| 	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
 | |
| 
 | |
| 	rw = httptest.NewRecorder()
 | |
| 	proxy.ServeHTTP(rw, req)
 | |
| 
 | |
| 	// The username in the basic auth credentials is expected to be equal to the email address from the
 | |
| 	// auth response, so we use the same variable here.
 | |
| 	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
 | |
| 	assert.Equal(t, expectedHeader, rw.Body.String())
 | |
| 	providerServer.Close()
 | |
| }
 | |
| 
 | |
| func TestBasicAuthWithEmail(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.PassBasicAuth = true
 | |
| 	opts.PassUserHeaders = false
 | |
| 	opts.PreferEmailToUser = false
 | |
| 	opts.BasicAuthPassword = "This is a secure password"
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	const emailAddress = "john.doe@example.com"
 | |
| 	const userName = "9fcab5c9b889a557"
 | |
| 
 | |
| 	// The username in the basic auth credentials is expected to be equal to the email address from the
 | |
| 	expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
 | |
| 	expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword))
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	session := &sessions.SessionState{
 | |
| 		User:        userName,
 | |
| 		Email:       emailAddress,
 | |
| 		AccessToken: "oauth_token",
 | |
| 		CreatedAt:   &created,
 | |
| 	}
 | |
| 	{
 | |
| 		rw := httptest.NewRecorder()
 | |
| 		req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil)
 | |
| 		proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | |
| 			return email == emailAddress
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		proxy.addHeadersForProxying(rw, req, session)
 | |
| 		assert.Equal(t, expectedUserHeader, req.Header["Authorization"][0])
 | |
| 		assert.Equal(t, userName, req.Header["X-Forwarded-User"][0])
 | |
| 	}
 | |
| 
 | |
| 	opts.PreferEmailToUser = true
 | |
| 	{
 | |
| 		rw := httptest.NewRecorder()
 | |
| 		req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil)
 | |
| 
 | |
| 		proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | |
| 			return email == emailAddress
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		proxy.addHeadersForProxying(rw, req, session)
 | |
| 		assert.Equal(t, expectedEmailHeader, req.Header["Authorization"][0])
 | |
| 		assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0])
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestPassUserHeadersWithEmail(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	const emailAddress = "john.doe@example.com"
 | |
| 	const userName = "9fcab5c9b889a557"
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	session := &sessions.SessionState{
 | |
| 		User:        userName,
 | |
| 		Email:       emailAddress,
 | |
| 		AccessToken: "oauth_token",
 | |
| 		CreatedAt:   &created,
 | |
| 	}
 | |
| 	{
 | |
| 		rw := httptest.NewRecorder()
 | |
| 		req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil)
 | |
| 		proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | |
| 			return email == emailAddress
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		proxy.addHeadersForProxying(rw, req, session)
 | |
| 		assert.Equal(t, userName, req.Header["X-Forwarded-User"][0])
 | |
| 	}
 | |
| 
 | |
| 	opts.PreferEmailToUser = true
 | |
| 	{
 | |
| 		rw := httptest.NewRecorder()
 | |
| 		req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil)
 | |
| 
 | |
| 		proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | |
| 			return email == emailAddress
 | |
| 		})
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		proxy.addHeadersForProxying(rw, req, session)
 | |
| 		assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0])
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestStripAuthHeaders(t *testing.T) {
 | |
| 	testCases := map[string]struct {
 | |
| 		SkipAuthStripHeaders bool
 | |
| 		PassBasicAuth        bool
 | |
| 		PassUserHeaders      bool
 | |
| 		PassAccessToken      bool
 | |
| 		PassAuthorization    bool
 | |
| 		StrippedHeaders      map[string]bool
 | |
| 	}{
 | |
| 		"Default options": {
 | |
| 			SkipAuthStripHeaders: true,
 | |
| 			PassBasicAuth:        true,
 | |
| 			PassUserHeaders:      true,
 | |
| 			PassAccessToken:      false,
 | |
| 			PassAuthorization:    false,
 | |
| 			StrippedHeaders: map[string]bool{
 | |
| 				"X-Forwarded-User":               true,
 | |
| 				"X-Forwarded-Email":              true,
 | |
| 				"X-Forwarded-Preferred-Username": true,
 | |
| 				"X-Forwarded-Access-Token":       false,
 | |
| 				"Authorization":                  true,
 | |
| 			},
 | |
| 		},
 | |
| 		"Pass access token": {
 | |
| 			SkipAuthStripHeaders: true,
 | |
| 			PassBasicAuth:        true,
 | |
| 			PassUserHeaders:      true,
 | |
| 			PassAccessToken:      true,
 | |
| 			PassAuthorization:    false,
 | |
| 			StrippedHeaders: map[string]bool{
 | |
| 				"X-Forwarded-User":               true,
 | |
| 				"X-Forwarded-Email":              true,
 | |
| 				"X-Forwarded-Preferred-Username": true,
 | |
| 				"X-Forwarded-Access-Token":       true,
 | |
| 				"Authorization":                  true,
 | |
| 			},
 | |
| 		},
 | |
| 		"Nothing setting Authorization": {
 | |
| 			SkipAuthStripHeaders: true,
 | |
| 			PassBasicAuth:        false,
 | |
| 			PassUserHeaders:      true,
 | |
| 			PassAccessToken:      true,
 | |
| 			PassAuthorization:    false,
 | |
| 			StrippedHeaders: map[string]bool{
 | |
| 				"X-Forwarded-User":               true,
 | |
| 				"X-Forwarded-Email":              true,
 | |
| 				"X-Forwarded-Preferred-Username": true,
 | |
| 				"X-Forwarded-Access-Token":       true,
 | |
| 				"Authorization":                  false,
 | |
| 			},
 | |
| 		},
 | |
| 		"Only Authorization header modified": {
 | |
| 			SkipAuthStripHeaders: true,
 | |
| 			PassBasicAuth:        false,
 | |
| 			PassUserHeaders:      false,
 | |
| 			PassAccessToken:      false,
 | |
| 			PassAuthorization:    true,
 | |
| 			StrippedHeaders: map[string]bool{
 | |
| 				"X-Forwarded-User":               false,
 | |
| 				"X-Forwarded-Email":              false,
 | |
| 				"X-Forwarded-Preferred-Username": false,
 | |
| 				"X-Forwarded-Access-Token":       false,
 | |
| 				"Authorization":                  true,
 | |
| 			},
 | |
| 		},
 | |
| 		"Don't strip any headers (default options)": {
 | |
| 			SkipAuthStripHeaders: false,
 | |
| 			PassBasicAuth:        true,
 | |
| 			PassUserHeaders:      true,
 | |
| 			PassAccessToken:      false,
 | |
| 			PassAuthorization:    false,
 | |
| 			StrippedHeaders: map[string]bool{
 | |
| 				"X-Forwarded-User":               false,
 | |
| 				"X-Forwarded-Email":              false,
 | |
| 				"X-Forwarded-Preferred-Username": false,
 | |
| 				"X-Forwarded-Access-Token":       false,
 | |
| 				"Authorization":                  false,
 | |
| 			},
 | |
| 		},
 | |
| 		"Don't strip any headers (custom options)": {
 | |
| 			SkipAuthStripHeaders: false,
 | |
| 			PassBasicAuth:        true,
 | |
| 			PassUserHeaders:      true,
 | |
| 			PassAccessToken:      true,
 | |
| 			PassAuthorization:    false,
 | |
| 			StrippedHeaders: map[string]bool{
 | |
| 				"X-Forwarded-User":               false,
 | |
| 				"X-Forwarded-Email":              false,
 | |
| 				"X-Forwarded-Preferred-Username": false,
 | |
| 				"X-Forwarded-Access-Token":       false,
 | |
| 				"Authorization":                  false,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	initialHeaders := map[string]string{
 | |
| 		"X-Forwarded-User":               "9fcab5c9b889a557",
 | |
| 		"X-Forwarded-Email":              "john.doe@example.com",
 | |
| 		"X-Forwarded-Preferred-Username": "john.doe",
 | |
| 		"X-Forwarded-Access-Token":       "AccessToken",
 | |
| 		"Authorization":                  "bearer IDToken",
 | |
| 	}
 | |
| 
 | |
| 	for name, tc := range testCases {
 | |
| 		t.Run(name, func(t *testing.T) {
 | |
| 			opts := baseTestOptions()
 | |
| 			opts.SkipAuthStripHeaders = tc.SkipAuthStripHeaders
 | |
| 			opts.PassBasicAuth = tc.PassBasicAuth
 | |
| 			opts.PassUserHeaders = tc.PassUserHeaders
 | |
| 			opts.PassAccessToken = tc.PassAccessToken
 | |
| 			opts.PassAuthorization = tc.PassAuthorization
 | |
| 			err := validation.Validate(opts)
 | |
| 			assert.NoError(t, err)
 | |
| 
 | |
| 			req, _ := http.NewRequest("GET", fmt.Sprintf("%s/testCase", opts.ProxyPrefix), nil)
 | |
| 			for header, val := range initialHeaders {
 | |
| 				req.Header.Set(header, val)
 | |
| 			}
 | |
| 
 | |
| 			proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | |
| 			assert.NoError(t, err)
 | |
| 			if proxy.skipAuthStripHeaders {
 | |
| 				proxy.stripAuthHeaders(req)
 | |
| 			}
 | |
| 
 | |
| 			for header, stripped := range tc.StrippedHeaders {
 | |
| 				if stripped {
 | |
| 					assert.Equal(t, req.Header.Get(header), "")
 | |
| 				} else {
 | |
| 					assert.Equal(t, req.Header.Get(header), initialHeaders[header])
 | |
| 				}
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type PassAccessTokenTest struct {
 | |
| 	providerServer *httptest.Server
 | |
| 	proxy          *OAuthProxy
 | |
| 	opts           *options.Options
 | |
| }
 | |
| 
 | |
| type PassAccessTokenTestOptions struct {
 | |
| 	PassAccessToken bool
 | |
| 	ProxyUpstream   string
 | |
| }
 | |
| 
 | |
| func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) {
 | |
| 	patt := &PassAccessTokenTest{}
 | |
| 
 | |
| 	patt.providerServer = httptest.NewServer(
 | |
| 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 			var payload string
 | |
| 			switch r.URL.Path {
 | |
| 			case "/oauth/token":
 | |
| 				payload = `{"access_token": "my_auth_token"}`
 | |
| 			default:
 | |
| 				payload = r.Header.Get("X-Forwarded-Access-Token")
 | |
| 				if payload == "" {
 | |
| 					payload = "No access token found."
 | |
| 				}
 | |
| 			}
 | |
| 			w.WriteHeader(200)
 | |
| 			_, err := w.Write([]byte(payload))
 | |
| 			if err != nil {
 | |
| 				panic(err)
 | |
| 			}
 | |
| 		}))
 | |
| 
 | |
| 	patt.opts = baseTestOptions()
 | |
| 	patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL)
 | |
| 	if opts.ProxyUpstream != "" {
 | |
| 		patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream)
 | |
| 	}
 | |
| 	patt.opts.Cookie.Secure = false
 | |
| 	patt.opts.PassAccessToken = opts.PassAccessToken
 | |
| 	err := validation.Validate(patt.opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	providerURL, _ := url.Parse(patt.providerServer.URL)
 | |
| 	const emailAddress = "michael.bland@gsa.gov"
 | |
| 
 | |
| 	patt.opts.SetProvider(NewTestProvider(providerURL, emailAddress))
 | |
| 	patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
 | |
| 		return email == emailAddress
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return patt, nil
 | |
| }
 | |
| 
 | |
| func (patTest *PassAccessTokenTest) Close() {
 | |
| 	patTest.providerServer.Close()
 | |
| }
 | |
| 
 | |
| func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
 | |
| 	cookie string) {
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
 | |
| 		strings.NewReader(""))
 | |
| 	if err != nil {
 | |
| 		return 0, ""
 | |
| 	}
 | |
| 	req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
 | |
| 	patTest.proxy.ServeHTTP(rw, req)
 | |
| 	return rw.Code, rw.Header().Values("Set-Cookie")[1]
 | |
| }
 | |
| 
 | |
| // getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
 | |
| // and cookie and returns body and status code.
 | |
| func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
 | |
| 	cookieName := patTest.proxy.CookieName
 | |
| 	var value string
 | |
| 	keyPrefix := cookieName + "="
 | |
| 
 | |
| 	for _, field := range strings.Split(cookie, "; ") {
 | |
| 		value = strings.TrimPrefix(field, keyPrefix)
 | |
| 		if value != field {
 | |
| 			break
 | |
| 		} else {
 | |
| 			value = ""
 | |
| 		}
 | |
| 	}
 | |
| 	if value == "" {
 | |
| 		return 0, ""
 | |
| 	}
 | |
| 
 | |
| 	req, err := http.NewRequest("GET", endpoint, strings.NewReader(""))
 | |
| 	if err != nil {
 | |
| 		return 0, ""
 | |
| 	}
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:     cookieName,
 | |
| 		Value:    value,
 | |
| 		Path:     "/",
 | |
| 		Expires:  time.Now().Add(time.Duration(24)),
 | |
| 		HttpOnly: true,
 | |
| 	})
 | |
| 
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	patTest.proxy.ServeHTTP(rw, req)
 | |
| 	return rw.Code, rw.Body.String()
 | |
| }
 | |
| 
 | |
| func TestForwardAccessTokenUpstream(t *testing.T) {
 | |
| 	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
 | |
| 		PassAccessToken: true,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	t.Cleanup(patTest.Close)
 | |
| 
 | |
| 	// A successful validation will redirect and set the auth cookie.
 | |
| 	code, cookie := patTest.getCallbackEndpoint()
 | |
| 	if code != 302 {
 | |
| 		t.Fatalf("expected 302; got %d", code)
 | |
| 	}
 | |
| 	assert.NotNil(t, cookie)
 | |
| 
 | |
| 	// Now we make a regular request; the access_token from the cookie is
 | |
| 	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | |
| 	// read by the test provider server and written in the response body.
 | |
| 	code, payload := patTest.getEndpointWithCookie(cookie, "/")
 | |
| 	if code != 200 {
 | |
| 		t.Fatalf("expected 200; got %d", code)
 | |
| 	}
 | |
| 	assert.Equal(t, "my_auth_token", payload)
 | |
| }
 | |
| 
 | |
| func TestStaticProxyUpstream(t *testing.T) {
 | |
| 	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
 | |
| 		PassAccessToken: true,
 | |
| 		ProxyUpstream:   "static://200/static-proxy",
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	t.Cleanup(patTest.Close)
 | |
| 
 | |
| 	// A successful validation will redirect and set the auth cookie.
 | |
| 	code, cookie := patTest.getCallbackEndpoint()
 | |
| 	if code != 302 {
 | |
| 		t.Fatalf("expected 302; got %d", code)
 | |
| 	}
 | |
| 	assert.NotEqual(t, nil, cookie)
 | |
| 
 | |
| 	// Now we make a regular request against the upstream proxy; And validate
 | |
| 	// the returned status code through the static proxy.
 | |
| 	code, payload := patTest.getEndpointWithCookie(cookie, "/static-proxy")
 | |
| 	if code != 200 {
 | |
| 		t.Fatalf("expected 200; got %d", code)
 | |
| 	}
 | |
| 	assert.Equal(t, "Authenticated", payload)
 | |
| }
 | |
| 
 | |
| func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
 | |
| 	patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{
 | |
| 		PassAccessToken: false,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	t.Cleanup(patTest.Close)
 | |
| 
 | |
| 	// A successful validation will redirect and set the auth cookie.
 | |
| 	code, cookie := patTest.getCallbackEndpoint()
 | |
| 	if code != 302 {
 | |
| 		t.Fatalf("expected 302; got %d", code)
 | |
| 	}
 | |
| 	assert.NotEqual(t, nil, cookie)
 | |
| 
 | |
| 	// Now we make a regular request, but the access token header should
 | |
| 	// not be present.
 | |
| 	code, payload := patTest.getEndpointWithCookie(cookie, "/")
 | |
| 	if code != 200 {
 | |
| 		t.Fatalf("expected 200; got %d", code)
 | |
| 	}
 | |
| 	assert.Equal(t, "No access token found.", payload)
 | |
| }
 | |
| 
 | |
| type SignInPageTest struct {
 | |
| 	opts                 *options.Options
 | |
| 	proxy                *OAuthProxy
 | |
| 	signInRegexp         *regexp.Regexp
 | |
| 	signInProviderRegexp *regexp.Regexp
 | |
| }
 | |
| 
 | |
const (
	// signInRedirectPattern captures the value of the hidden "rd" input of
	// the sign-in form.
	signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`

	// signInSkipProvider is looked for in the body of the redirect responses
	// produced when the provider button is skipped.
	signInSkipProvider = `>Found<`
)
 | |
| 
 | |
| func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
 | |
| 	var sipTest SignInPageTest
 | |
| 
 | |
| 	sipTest.opts = baseTestOptions()
 | |
| 	sipTest.opts.SkipProviderButton = skipProvider
 | |
| 	err := validation.Validate(sipTest.opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool {
 | |
| 		return true
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern)
 | |
| 	sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider)
 | |
| 
 | |
| 	return &sipTest, nil
 | |
| }
 | |
| 
 | |
| func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
 | |
| 	sipTest.proxy.ServeHTTP(rw, req)
 | |
| 	return rw.Code, rw.Body.String()
 | |
| }
 | |
| 
 | |
| func TestSignInPageIncludesTargetRedirect(t *testing.T) {
 | |
| 	sipTest, err := NewSignInPageTest(false)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	const endpoint = "/some/random/endpoint"
 | |
| 
 | |
| 	code, body := sipTest.GetEndpoint(endpoint)
 | |
| 	assert.Equal(t, 403, code)
 | |
| 
 | |
| 	match := sipTest.signInRegexp.FindStringSubmatch(body)
 | |
| 	if match == nil {
 | |
| 		t.Fatal("Did not find pattern in body: " +
 | |
| 			signInRedirectPattern + "\nBody:\n" + body)
 | |
| 	}
 | |
| 	if match[1] != endpoint {
 | |
| 		t.Fatal(`expected redirect to "` + endpoint +
 | |
| 			`", but was "` + match[1] + `"`)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
 | |
| 	sipTest, err := NewSignInPageTest(false)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	code, body := sipTest.GetEndpoint("/oauth2/sign_in")
 | |
| 	assert.Equal(t, 200, code)
 | |
| 
 | |
| 	match := sipTest.signInRegexp.FindStringSubmatch(body)
 | |
| 	if match == nil {
 | |
| 		t.Fatal("Did not find pattern in body: " +
 | |
| 			signInRedirectPattern + "\nBody:\n" + body)
 | |
| 	}
 | |
| 	if match[1] != "/" {
 | |
| 		t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestSignInPageSkipProvider(t *testing.T) {
 | |
| 	sipTest, err := NewSignInPageTest(true)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	endpoint := "/some/random/endpoint"
 | |
| 
 | |
| 	code, body := sipTest.GetEndpoint(endpoint)
 | |
| 	assert.Equal(t, 302, code)
 | |
| 
 | |
| 	match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
 | |
| 	if match == nil {
 | |
| 		t.Fatal("Did not find pattern in body: " +
 | |
| 			signInSkipProvider + "\nBody:\n" + body)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestSignInPageSkipProviderDirect(t *testing.T) {
 | |
| 	sipTest, err := NewSignInPageTest(true)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	endpoint := "/sign_in"
 | |
| 
 | |
| 	code, body := sipTest.GetEndpoint(endpoint)
 | |
| 	assert.Equal(t, 302, code)
 | |
| 
 | |
| 	match := sipTest.signInProviderRegexp.FindStringSubmatch(body)
 | |
| 	if match == nil {
 | |
| 		t.Fatal("Did not find pattern in body: " +
 | |
| 			signInSkipProvider + "\nBody:\n" + body)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type ProcessCookieTest struct {
 | |
| 	opts         *options.Options
 | |
| 	proxy        *OAuthProxy
 | |
| 	rw           *httptest.ResponseRecorder
 | |
| 	req          *http.Request
 | |
| 	validateUser bool
 | |
| }
 | |
| 
 | |
// ProcessCookieTestOpts configures the fixture built by NewProcessCookieTest.
type ProcessCookieTestOpts struct {
	// providerValidateCookieResponse becomes the ValidToken value of the
	// TestProvider installed on the proxy.
	providerValidateCookieResponse bool
}
 | |
| 
 | |
| type OptionsModifier func(*options.Options)
 | |
| 
 | |
| func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
 | |
| 	var pcTest ProcessCookieTest
 | |
| 
 | |
| 	pcTest.opts = baseTestOptions()
 | |
| 	for _, modifier := range modifiers {
 | |
| 		modifier(pcTest.opts)
 | |
| 	}
 | |
| 	// First, set the CookieRefresh option so proxy.AesCipher is created,
 | |
| 	// needed to encrypt the access_token.
 | |
| 	pcTest.opts.Cookie.Refresh = time.Hour
 | |
| 	err := validation.Validate(pcTest.opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | |
| 		return pcTest.validateUser
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	pcTest.proxy.provider = &TestProvider{
 | |
| 		ValidToken: opts.providerValidateCookieResponse,
 | |
| 	}
 | |
| 
 | |
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | |
| 	// access_token validation.
 | |
| 	pcTest.proxy.CookieRefresh = time.Duration(0)
 | |
| 	pcTest.rw = httptest.NewRecorder()
 | |
| 	pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
 | |
| 	pcTest.validateUser = true
 | |
| 	return &pcTest, nil
 | |
| }
 | |
| 
 | |
| func NewProcessCookieTestWithDefaults() (*ProcessCookieTest, error) {
 | |
| 	return NewProcessCookieTest(ProcessCookieTestOpts{
 | |
| 		providerValidateCookieResponse: true,
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
 | |
| 	return NewProcessCookieTest(ProcessCookieTestOpts{
 | |
| 		providerValidateCookieResponse: true,
 | |
| 	}, modifiers...)
 | |
| }
 | |
| 
 | |
| func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
 | |
| 	err := p.proxy.SaveSession(p.rw, p.req, s)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	for _, cookie := range p.rw.Result().Cookies() {
 | |
| 		p.req.AddCookie(cookie)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
 | |
| 	return p.proxy.LoadCookiedSession(p.req)
 | |
| }
 | |
| 
 | |
| func TestLoadCookiedSession(t *testing.T) {
 | |
| 	pcTest, err := NewProcessCookieTestWithDefaults()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	session, err := pcTest.LoadCookiedSession()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	assert.Equal(t, startSession.Email, session.Email)
 | |
| 	assert.Equal(t, "", session.User)
 | |
| 	assert.Equal(t, startSession.AccessToken, session.AccessToken)
 | |
| }
 | |
| 
 | |
| func TestProcessCookieNoCookieError(t *testing.T) {
 | |
| 	pcTest, err := NewProcessCookieTestWithDefaults()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	session, err := pcTest.LoadCookiedSession()
 | |
| 	assert.Error(t, err, "cookie \"_oauth2_proxy\" not present")
 | |
| 	if session != nil {
 | |
| 		t.Errorf("expected nil session. got %#v", session)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessCookieRefreshNotSet(t *testing.T) {
 | |
| 	pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
 | |
| 		opts.Cookie.Expire = time.Duration(23) * time.Hour
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour)
 | |
| 
 | |
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	session, err := pcTest.LoadCookiedSession()
 | |
| 	assert.Equal(t, nil, err)
 | |
| 	if session.Age() < time.Duration(-2)*time.Hour {
 | |
| 		t.Errorf("cookie too young %v", session.Age())
 | |
| 	}
 | |
| 	assert.Equal(t, startSession.Email, session.Email)
 | |
| }
 | |
| 
 | |
| func TestProcessCookieFailIfCookieExpired(t *testing.T) {
 | |
| 	pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
 | |
| 		opts.Cookie.Expire = time.Duration(24) * time.Hour
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | |
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	session, err := pcTest.LoadCookiedSession()
 | |
| 	assert.NotEqual(t, nil, err)
 | |
| 	if session != nil {
 | |
| 		t.Errorf("expected nil session %#v", session)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
 | |
| 	pcTest, err := NewProcessCookieTestWithOptionsModifiers(func(opts *options.Options) {
 | |
| 		opts.Cookie.Expire = time.Duration(24) * time.Hour
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | |
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy.CookieRefresh = time.Hour
 | |
| 	session, err := pcTest.LoadCookiedSession()
 | |
| 	assert.NotEqual(t, nil, err)
 | |
| 	if session != nil {
 | |
| 		t.Errorf("expected nil session %#v", session)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func NewUserInfoEndpointTest() (*ProcessCookieTest, error) {
 | |
| 	pcTest, err := NewProcessCookieTestWithDefaults()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	pcTest.req, _ = http.NewRequest("GET",
 | |
| 		pcTest.opts.ProxyPrefix+"/userinfo", nil)
 | |
| 	return pcTest, nil
 | |
| }
 | |
| 
 | |
| func TestUserInfoEndpointAccepted(t *testing.T) {
 | |
| 	test, err := NewUserInfoEndpointTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		Email: "john.doe@example.com", AccessToken: "my_access_token"}
 | |
| 	err = test.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	assert.Equal(t, http.StatusOK, test.rw.Code)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
 | |
| 	assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes))
 | |
| }
 | |
| 
 | |
| func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
 | |
| 	test, err := NewUserInfoEndpointTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
 | |
| }
 | |
| 
 | |
| func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
 | |
| 	pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	pcTest.req, _ = http.NewRequest("GET",
 | |
| 		pcTest.opts.ProxyPrefix+"/auth", nil)
 | |
| 	return pcTest, nil
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointAccepted(t *testing.T) {
 | |
| 	test, err := NewAuthOnlyEndpointTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
 | |
| 	err = test.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	assert.Equal(t, http.StatusAccepted, test.rw.Code)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
 | |
| 	assert.Equal(t, "", string(bodyBytes))
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
 | |
| 	test, err := NewAuthOnlyEndpointTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
 | |
| 	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
 | |
| 	test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
 | |
| 		opts.Cookie.Expire = time.Duration(24) * time.Hour
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
 | |
| 	err = test.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
 | |
| 	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
 | |
| 	test, err := NewAuthOnlyEndpointTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
 | |
| 	err = test.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 	test.validateUser = false
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
 | |
| 	bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
 | |
| 	assert.Equal(t, "unauthorized request\n", string(bodyBytes))
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
 | |
| 	var pcTest ProcessCookieTest
 | |
| 
 | |
| 	pcTest.opts = baseTestOptions()
 | |
| 	pcTest.opts.SetXAuthRequest = true
 | |
| 	err := validation.Validate(pcTest.opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | |
| 		return pcTest.validateUser
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	pcTest.proxy.provider = &TestProvider{
 | |
| 		ValidToken: true,
 | |
| 	}
 | |
| 
 | |
| 	pcTest.validateUser = true
 | |
| 
 | |
| 	pcTest.rw = httptest.NewRecorder()
 | |
| 	pcTest.req, _ = http.NewRequest("GET",
 | |
| 		pcTest.opts.ProxyPrefix+"/auth", nil)
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
 | |
| 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
 | |
| 	assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
 | |
| 	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
 | |
| 	var pcTest ProcessCookieTest
 | |
| 
 | |
| 	pcTest.opts = baseTestOptions()
 | |
| 	pcTest.opts.SetXAuthRequest = true
 | |
| 	pcTest.opts.SetBasicAuth = true
 | |
| 	err := validation.Validate(pcTest.opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | |
| 		return pcTest.validateUser
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	pcTest.proxy.provider = &TestProvider{
 | |
| 		ValidToken: true,
 | |
| 	}
 | |
| 
 | |
| 	pcTest.validateUser = true
 | |
| 
 | |
| 	pcTest.rw = httptest.NewRecorder()
 | |
| 	pcTest.req, _ = http.NewRequest("GET",
 | |
| 		pcTest.opts.ProxyPrefix+"/auth", nil)
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
 | |
| 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
 | |
| 	assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
 | |
| 	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
 | |
| 	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:"+pcTest.opts.BasicAuthPassword))
 | |
| 	assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0])
 | |
| }
 | |
| 
 | |
| func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
 | |
| 	var pcTest ProcessCookieTest
 | |
| 
 | |
| 	pcTest.opts = baseTestOptions()
 | |
| 	pcTest.opts.SetXAuthRequest = true
 | |
| 	pcTest.opts.SetBasicAuth = false
 | |
| 	err := validation.Validate(pcTest.opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | |
| 		return pcTest.validateUser
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	pcTest.proxy.provider = &TestProvider{
 | |
| 		ValidToken: true,
 | |
| 	}
 | |
| 
 | |
| 	pcTest.validateUser = true
 | |
| 
 | |
| 	pcTest.rw = httptest.NewRecorder()
 | |
| 	pcTest.req, _ = http.NewRequest("GET",
 | |
| 		pcTest.opts.ProxyPrefix+"/auth", nil)
 | |
| 
 | |
| 	created := time.Now()
 | |
| 	startSession := &sessions.SessionState{
 | |
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
 | |
| 	err = pcTest.SaveSession(startSession)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
 | |
| 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
 | |
| 	assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0])
 | |
| 	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0])
 | |
| 	assert.Equal(t, 0, len(pcTest.rw.Header().Values("Authorization")), "should not have Authorization header entries")
 | |
| }
 | |
| 
 | |
| func TestAuthSkippedForPreflightRequests(t *testing.T) {
 | |
| 	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		w.WriteHeader(200)
 | |
| 		_, err := w.Write([]byte("response"))
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 	}))
 | |
| 	t.Cleanup(upstream.Close)
 | |
| 
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.Upstreams = append(opts.Upstreams, upstream.URL)
 | |
| 	opts.SkipAuthPreflight = true
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	upstreamURL, _ := url.Parse(upstream.URL)
 | |
| 	opts.SetProvider(NewTestProvider(upstreamURL, ""))
 | |
| 
 | |
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
 | |
| 	proxy.ServeHTTP(rw, req)
 | |
| 
 | |
| 	assert.Equal(t, 200, rw.Code)
 | |
| 	assert.Equal(t, "response", rw.Body.String())
 | |
| }
 | |
| 
 | |
| type SignatureAuthenticator struct {
 | |
| 	auth hmacauth.HmacAuth
 | |
| }
 | |
| 
 | |
| func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
 | |
| 	result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
 | |
| 
 | |
| 	var msg string
 | |
| 	switch result {
 | |
| 	case hmacauth.ResultNoSignature:
 | |
| 		msg = "no signature received"
 | |
| 	case hmacauth.ResultMatch:
 | |
| 		msg = "signatures match"
 | |
| 	case hmacauth.ResultMismatch:
 | |
| 		msg = fmt.Sprintf(
 | |
| 			"signatures do not match:\n  received: %s\n  computed: %s",
 | |
| 			headerSig,
 | |
| 			computedSig)
 | |
| 	default:
 | |
| 		panic("unknown result value: " + result.String())
 | |
| 	}
 | |
| 
 | |
| 	_, err := w.Write([]byte(msg))
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type SignatureTest struct {
 | |
| 	opts          *options.Options
 | |
| 	upstream      *httptest.Server
 | |
| 	upstreamHost  string
 | |
| 	provider      *httptest.Server
 | |
| 	header        http.Header
 | |
| 	rw            *httptest.ResponseRecorder
 | |
| 	authenticator *SignatureAuthenticator
 | |
| }
 | |
| 
 | |
| func NewSignatureTest() (*SignatureTest, error) {
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.EmailDomains = []string{"acm.org"}
 | |
| 
 | |
| 	authenticator := &SignatureAuthenticator{}
 | |
| 	upstream := httptest.NewServer(
 | |
| 		http.HandlerFunc(authenticator.Authenticate))
 | |
| 	upstreamURL, err := url.Parse(upstream.URL)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	opts.Upstreams = append(opts.Upstreams, upstream.URL)
 | |
| 
 | |
| 	providerHandler := func(w http.ResponseWriter, r *http.Request) {
 | |
| 		_, err := w.Write([]byte(`{"access_token": "my_auth_token"}`))
 | |
| 		if err != nil {
 | |
| 			panic(err)
 | |
| 		}
 | |
| 	}
 | |
| 	provider := httptest.NewServer(http.HandlerFunc(providerHandler))
 | |
| 	providerURL, err := url.Parse(provider.URL)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org"))
 | |
| 
 | |
| 	return &SignatureTest{
 | |
| 		opts,
 | |
| 		upstream,
 | |
| 		upstreamURL.Host,
 | |
| 		provider,
 | |
| 		make(http.Header),
 | |
| 		httptest.NewRecorder(),
 | |
| 		authenticator,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (st *SignatureTest) Close() {
 | |
| 	st.provider.Close()
 | |
| 	st.upstream.Close()
 | |
| }
 | |
| 
 | |
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly.
// See https://github.com/18F/hmacauth/pull/4.
type fakeNetConn struct {
	reqBody string
}

// Read copies as much of the remaining body as fits into p and removes the
// copied bytes from the fake connection. It returns the number of bytes
// copied, which never exceeds len(p), and io.EOF once the body is exhausted;
// the final bytes and io.EOF are returned by the same call.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	n = copy(p, fnc.reqBody)
	// Keep whatever did not fit so that a later Read can deliver it.
	fnc.reqBody = fnc.reqBody[n:]
	if fnc.reqBody == "" {
		return n, io.EOF
	}
	return n, nil
}
 | |
| 
 | |
| func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) error {
 | |
| 	err := validation.Validate(st.opts)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	var bodyBuf io.ReadCloser
 | |
| 	if body != "" {
 | |
| 		bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
 | |
| 	}
 | |
| 	req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
 | |
| 	req.Header = st.header
 | |
| 
 | |
| 	state := &sessions.SessionState{
 | |
| 		Email: "mbland@acm.org", AccessToken: "my_access_token"}
 | |
| 	err = proxy.SaveSession(st.rw, req, state)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	for _, c := range st.rw.Result().Cookies() {
 | |
| 		req.AddCookie(c)
 | |
| 	}
 | |
| 	// This is used by the upstream to validate the signature.
 | |
| 	st.authenticator.auth = hmacauth.NewHmacAuth(
 | |
| 		crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
 | |
| 	proxy.ServeHTTP(st.rw, req)
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func TestRequestSignature(t *testing.T) {
 | |
| 	testCases := map[string]struct {
 | |
| 		method string
 | |
| 		body   string
 | |
| 		key    string
 | |
| 		resp   string
 | |
| 	}{
 | |
| 		"No request signature": {
 | |
| 			method: "GET",
 | |
| 			body:   "",
 | |
| 			key:    "",
 | |
| 			resp:   "no signature received",
 | |
| 		},
 | |
| 		"Get request": {
 | |
| 			method: "GET",
 | |
| 			body:   "",
 | |
| 			key:    "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d",
 | |
| 			resp:   "signatures match",
 | |
| 		},
 | |
| 		"Post request": {
 | |
| 			method: "POST",
 | |
| 			body:   `{ "hello": "world!" }`,
 | |
| 			key:    "d90df39e2d19282840252612dd7c81421a372f61",
 | |
| 			resp:   "signatures match",
 | |
| 		},
 | |
| 	}
 | |
| 	for name, tc := range testCases {
 | |
| 		t.Run(name, func(t *testing.T) {
 | |
| 			st, err := NewSignatureTest()
 | |
| 			if err != nil {
 | |
| 				t.Fatal(err)
 | |
| 			}
 | |
| 			t.Cleanup(st.Close)
 | |
| 			if tc.key != "" {
 | |
| 				st.opts.SignatureKey = fmt.Sprintf("sha1:%s", tc.key)
 | |
| 			}
 | |
| 			err = st.MakeRequestWithExpectedKey(tc.method, tc.body, tc.key)
 | |
| 			assert.NoError(t, err)
 | |
| 			assert.Equal(t, 200, st.rw.Code)
 | |
| 			assert.Equal(t, tc.resp, st.rw.Body.String())
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGetRedirect(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 	require.NotEmpty(t, opts.ProxyPrefix)
 | |
| 	proxy, err := NewOAuthProxy(opts, func(s string) bool { return false })
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	tests := []struct {
 | |
| 		name             string
 | |
| 		url              string
 | |
| 		expectedRedirect string
 | |
| 	}{
 | |
| 		{
 | |
| 			name:             "request outside of ProxyPrefix redirects to original URL",
 | |
| 			url:              "/foo/bar",
 | |
| 			expectedRedirect: "/foo/bar",
 | |
| 		},
 | |
| 		{
 | |
| 			name:             "request under ProxyPrefix redirects to root",
 | |
| 			url:              proxy.ProxyPrefix + "/foo/bar",
 | |
| 			expectedRedirect: "/",
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			req, _ := http.NewRequest("GET", tt.url, nil)
 | |
| 			redirect, err := proxy.GetRedirect(req)
 | |
| 
 | |
| 			assert.NoError(t, err)
 | |
| 			assert.Equal(t, tt.expectedRedirect, redirect)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type ajaxRequestTest struct {
 | |
| 	opts  *options.Options
 | |
| 	proxy *OAuthProxy
 | |
| }
 | |
| 
 | |
| func newAjaxRequestTest() (*ajaxRequestTest, error) {
 | |
| 	test := &ajaxRequestTest{}
 | |
| 	test.opts = baseTestOptions()
 | |
| 	err := validation.Validate(test.opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
 | |
| 		return true
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return test, nil
 | |
| }
 | |
| 
 | |
| func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) {
 | |
| 	rw := httptest.NewRecorder()
 | |
| 	req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
 | |
| 	if err != nil {
 | |
| 		return 0, nil, err
 | |
| 	}
 | |
| 	req.Header = header
 | |
| 	test.proxy.ServeHTTP(rw, req)
 | |
| 	return rw.Code, rw.Header(), nil
 | |
| }
 | |
| 
 | |
| func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) {
 | |
| 	test, err := newAjaxRequestTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	endpoint := "/test"
 | |
| 
 | |
| 	code, rh, err := test.getEndpoint(endpoint, header)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, http.StatusUnauthorized, code)
 | |
| 	mime := rh.Get("Content-Type")
 | |
| 	assert.Equal(t, applicationJSON, mime)
 | |
| }
 | |
| func TestAjaxUnauthorizedRequest1(t *testing.T) {
 | |
| 	header := make(http.Header)
 | |
| 	header.Add("accept", applicationJSON)
 | |
| 
 | |
| 	testAjaxUnauthorizedRequest(t, header)
 | |
| }
 | |
| 
 | |
| func TestAjaxUnauthorizedRequest2(t *testing.T) {
 | |
| 	header := make(http.Header)
 | |
| 	header.Add("Accept", applicationJSON)
 | |
| 
 | |
| 	testAjaxUnauthorizedRequest(t, header)
 | |
| }
 | |
| 
 | |
| func TestAjaxForbiddendRequest(t *testing.T) {
 | |
| 	test, err := newAjaxRequestTest()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	endpoint := "/test"
 | |
| 	header := make(http.Header)
 | |
| 	code, rh, err := test.getEndpoint(endpoint, header)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, http.StatusForbidden, code)
 | |
| 	mime := rh.Get("Content-Type")
 | |
| 	assert.NotEqual(t, applicationJSON, mime)
 | |
| }
 | |
| 
 | |
| func TestClearSplitCookie(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.Cookie.Secret = base64CookieSecret
 | |
| 	opts.Cookie.Name = "oauth2"
 | |
| 	opts.Cookie.Domains = []string{"abc"}
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 
 | |
| 	store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
 | |
| 	var rw = httptest.NewRecorder()
 | |
| 	req := httptest.NewRequest("get", "/", nil)
 | |
| 
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:  "test1",
 | |
| 		Value: "test1",
 | |
| 	})
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:  "oauth2_0",
 | |
| 		Value: "oauth2_0",
 | |
| 	})
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:  "oauth2_1",
 | |
| 		Value: "oauth2_1",
 | |
| 	})
 | |
| 
 | |
| 	err = p.ClearSessionCookie(rw, req)
 | |
| 	assert.NoError(t, err)
 | |
| 	header := rw.Header()
 | |
| 
 | |
| 	assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries")
 | |
| }
 | |
| 
 | |
| func TestClearSingleCookie(t *testing.T) {
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.Cookie.Name = "oauth2"
 | |
| 	opts.Cookie.Domains = []string{"abc"}
 | |
| 	store, err := sessionscookie.NewCookieSessionStore(&opts.Session, &opts.Cookie)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
 | |
| 	var rw = httptest.NewRecorder()
 | |
| 	req := httptest.NewRequest("get", "/", nil)
 | |
| 
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:  "test1",
 | |
| 		Value: "test1",
 | |
| 	})
 | |
| 	req.AddCookie(&http.Cookie{
 | |
| 		Name:  "oauth2",
 | |
| 		Value: "oauth2",
 | |
| 	})
 | |
| 
 | |
| 	err = p.ClearSessionCookie(rw, req)
 | |
| 	assert.NoError(t, err)
 | |
| 	header := rw.Header()
 | |
| 
 | |
| 	assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
 | |
| }
 | |
| 
 | |
// NoOpKeySet is a key set for tests that performs no signature verification;
// it is handed to oidc.NewVerifier in place of a real key set.
type NoOpKeySet struct {
}

// VerifySignature returns the decoded payload of jwt without checking its
// signature. The payload is the second dot-separated segment, decoded as
// unpadded base64url. An error is returned when jwt has fewer than two
// segments or the payload segment cannot be decoded.
func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
	segments := strings.Split(jwt, ".")
	// Guard the index below: a token without any "." yields a single segment.
	if len(segments) < 2 {
		return nil, fmt.Errorf("malformed jwt: expected at least 2 segments, got %d", len(segments))
	}
	return base64.RawURLEncoding.DecodeString(segments[1])
}
 | |
| 
 | |
| func TestGetJwtSession(t *testing.T) {
 | |
| 	/* token payload:
 | |
| 	{
 | |
| 	  "sub": "1234567890",
 | |
| 	  "aud": "https://test.myapp.com",
 | |
| 	  "name": "John Doe",
 | |
| 	  "email": "john@example.com",
 | |
| 	  "iss": "https://issuer.example.com",
 | |
| 	  "iat": 1553691215,
 | |
| 	  "exp": 1912151821
 | |
| 	}
 | |
| 	*/
 | |
| 	goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
 | |
| 		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
 | |
| 		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
 | |
| 		"E1LCJleHAiOjE5MTIxNTE4MjF9." +
 | |
| 		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
 | |
| 		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"
 | |
| 
 | |
| 	keyset := NoOpKeySet{}
 | |
| 	verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
 | |
| 		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
 | |
| 
 | |
| 	test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
 | |
| 		opts.PassAuthorization = true
 | |
| 		opts.SetAuthorization = true
 | |
| 		opts.SetXAuthRequest = true
 | |
| 		opts.SkipJwtBearerTokens = true
 | |
| 		opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	tp, _ := test.proxy.provider.(*TestProvider)
 | |
| 	tp.GroupValidator = func(s string) bool {
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	authHeader := fmt.Sprintf("Bearer %s", goodJwt)
 | |
| 	test.req.Header = map[string][]string{
 | |
| 		"Authorization": {authHeader},
 | |
| 	}
 | |
| 
 | |
| 	// Bearer
 | |
| 	expires := time.Unix(1912151821, 0)
 | |
| 	session, err := test.proxy.GetJwtSession(test.req)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, session.User, "1234567890")
 | |
| 	assert.Equal(t, session.Email, "john@example.com")
 | |
| 	assert.Equal(t, session.ExpiresOn, &expires)
 | |
| 	assert.Equal(t, session.IDToken, goodJwt)
 | |
| 
 | |
| 	test.proxy.ServeHTTP(test.rw, test.req)
 | |
| 	if test.rw.Code >= 400 {
 | |
| 		t.Fatalf("expected 3xx got %d", test.rw.Code)
 | |
| 	}
 | |
| 
 | |
| 	// Check PassAuthorization, should overwrite Basic header
 | |
| 	assert.Equal(t, test.req.Header.Get("Authorization"), authHeader)
 | |
| 	assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "1234567890")
 | |
| 	assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com")
 | |
| 
 | |
| 	// SetAuthorization and SetXAuthRequest
 | |
| 	assert.Equal(t, test.rw.Header().Get("Authorization"), authHeader)
 | |
| 	assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "1234567890")
 | |
| 	assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
 | |
| }
 | |
| 
 | |
| func TestFindJwtBearerToken(t *testing.T) {
 | |
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}}
 | |
| 	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}
 | |
| 
 | |
| 	validToken := "eyJfoobar.eyJfoobar.12345asdf"
 | |
| 	var token string
 | |
| 
 | |
| 	// Bearer
 | |
| 	getReq.Header = map[string][]string{
 | |
| 		"Authorization": {fmt.Sprintf("Bearer %s", validToken)},
 | |
| 	}
 | |
| 
 | |
| 	token, err := p.findBearerToken(getReq)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, validToken, token)
 | |
| 
 | |
| 	// Basic - no password
 | |
| 	getReq.SetBasicAuth(token, "")
 | |
| 	token, err = p.findBearerToken(getReq)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, validToken, token)
 | |
| 
 | |
| 	// Basic - sentinel password
 | |
| 	getReq.SetBasicAuth(token, "x-oauth-basic")
 | |
| 	token, err = p.findBearerToken(getReq)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, validToken, token)
 | |
| 
 | |
| 	// Basic - any username, password matching jwt pattern
 | |
| 	getReq.SetBasicAuth("any-username-you-could-wish-for", token)
 | |
| 	token, err = p.findBearerToken(getReq)
 | |
| 	assert.NoError(t, err)
 | |
| 	assert.Equal(t, validToken, token)
 | |
| 
 | |
| 	failures := []string{
 | |
| 		// Too many parts
 | |
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
 | |
| 		// Not enough parts
 | |
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA",
 | |
| 		// Invalid encrypted key
 | |
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA",
 | |
| 		// Invalid IV
 | |
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA",
 | |
| 		// Invalid ciphertext
 | |
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA",
 | |
| 		// Invalid tag
 | |
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////",
 | |
| 		// Invalid header
 | |
| 		"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
 | |
| 		// Invalid header
 | |
| 		"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
 | |
| 		// Missing alc/enc params
 | |
| 		"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
 | |
| 	}
 | |
| 
 | |
| 	for _, failure := range failures {
 | |
| 		getReq.Header = map[string][]string{
 | |
| 			"Authorization": {fmt.Sprintf("Bearer %s", failure)},
 | |
| 		}
 | |
| 		_, err := p.findBearerToken(getReq)
 | |
| 		assert.Error(t, err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_prepareNoCache(t *testing.T) {
 | |
| 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		prepareNoCache(w)
 | |
| 	})
 | |
| 	mux := http.NewServeMux()
 | |
| 	mux.Handle("/", handler)
 | |
| 
 | |
| 	rec := httptest.NewRecorder()
 | |
| 	req := httptest.NewRequest(http.MethodGet, "/", nil)
 | |
| 	mux.ServeHTTP(rec, req)
 | |
| 
 | |
| 	for k, v := range noCacheHeaders {
 | |
| 		assert.Equal(t, rec.Header().Get(k), v)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Test_noCacheHeaders(t *testing.T) {
 | |
| 	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		_, err := w.Write([]byte("upstream"))
 | |
| 		if err != nil {
 | |
| 			t.Error(err)
 | |
| 		}
 | |
| 	}))
 | |
| 	t.Cleanup(upstream.Close)
 | |
| 
 | |
| 	opts := baseTestOptions()
 | |
| 	opts.Upstreams = []string{upstream.URL}
 | |
| 	opts.SkipAuthRegex = []string{".*"}
 | |
| 	err := validation.Validate(opts)
 | |
| 	assert.NoError(t, err)
 | |
| 	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	t.Run("not exist in response from upstream", func(t *testing.T) {
 | |
| 		rec := httptest.NewRecorder()
 | |
| 		req := httptest.NewRequest(http.MethodGet, "/upstream", nil)
 | |
| 		proxy.ServeHTTP(rec, req)
 | |
| 
 | |
| 		assert.Equal(t, http.StatusOK, rec.Code)
 | |
| 		assert.Equal(t, "upstream", rec.Body.String())
 | |
| 
 | |
| 		// checking noCacheHeaders does not exists in response headers from upstream
 | |
| 		for k := range noCacheHeaders {
 | |
| 			assert.Equal(t, "", rec.Header().Get(k))
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	t.Run("has no-cache", func(t *testing.T) {
 | |
| 		tests := []struct {
 | |
| 			path       string
 | |
| 			hasNoCache bool
 | |
| 		}{
 | |
| 			{
 | |
| 				path:       "/oauth2/sign_in",
 | |
| 				hasNoCache: true,
 | |
| 			},
 | |
| 			{
 | |
| 				path:       "/oauth2/sign_out",
 | |
| 				hasNoCache: true,
 | |
| 			},
 | |
| 			{
 | |
| 				path:       "/oauth2/start",
 | |
| 				hasNoCache: true,
 | |
| 			},
 | |
| 			{
 | |
| 				path:       "/oauth2/callback",
 | |
| 				hasNoCache: true,
 | |
| 			},
 | |
| 			{
 | |
| 				path:       "/oauth2/auth",
 | |
| 				hasNoCache: false,
 | |
| 			},
 | |
| 			{
 | |
| 				path:       "/oauth2/userinfo",
 | |
| 				hasNoCache: true,
 | |
| 			},
 | |
| 			{
 | |
| 				path:       "/upstream",
 | |
| 				hasNoCache: false,
 | |
| 			},
 | |
| 		}
 | |
| 
 | |
| 		for _, tt := range tests {
 | |
| 			t.Run(tt.path, func(t *testing.T) {
 | |
| 				rec := httptest.NewRecorder()
 | |
| 				req := httptest.NewRequest(http.MethodGet, tt.path, nil)
 | |
| 				proxy.ServeHTTP(rec, req)
 | |
| 				cacheControl := rec.Result().Header.Get("Cache-Control")
 | |
| 				if tt.hasNoCache != (strings.Contains(cacheControl, "no-cache")) {
 | |
| 					t.Errorf(`unexpected "Cache-Control" header: %s`, cacheControl)
 | |
| 				}
 | |
| 			})
 | |
| 		}
 | |
| 
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func baseTestOptions() *options.Options {
 | |
| 	opts := options.NewOptions()
 | |
| 	opts.Cookie.Secret = rawCookieSecret
 | |
| 	opts.ClientID = clientID
 | |
| 	opts.ClientSecret = clientSecret
 | |
| 	opts.EmailDomains = []string{"*"}
 | |
| 	return opts
 | |
| }
 | |
| 
 | |
| func TestTrustedIPs(t *testing.T) {
 | |
| 	tests := []struct {
 | |
| 		name               string
 | |
| 		trustedIPs         []string
 | |
| 		reverseProxy       bool
 | |
| 		realClientIPHeader string
 | |
| 		req                *http.Request
 | |
| 		expectTrusted      bool
 | |
| 	}{
 | |
| 		// Check unconfigured behavior.
 | |
| 		{
 | |
| 			name:               "Default",
 | |
| 			trustedIPs:         nil,
 | |
| 			reverseProxy:       false,
 | |
| 			realClientIPHeader: "X-Real-IP", // Default value
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 		// Check using req.RemoteAddr (Options.ReverseProxy == false).
 | |
| 		{
 | |
| 			name:               "WithRemoteAddr",
 | |
| 			trustedIPs:         []string{"127.0.0.1"},
 | |
| 			reverseProxy:       false,
 | |
| 			realClientIPHeader: "X-Real-IP", // Default value
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.RemoteAddr = "127.0.0.1:43670"
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: true,
 | |
| 		},
 | |
| 		// Check ignores req.RemoteAddr match when behind a reverse proxy / missing header.
 | |
| 		{
 | |
| 			name:               "IgnoresRemoteAddrInReverseProxyMode",
 | |
| 			trustedIPs:         []string{"127.0.0.1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Real-IP", // Default value
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.RemoteAddr = "127.0.0.1:44324"
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 		// Check successful trusting of localhost in IPv4.
 | |
| 		{
 | |
| 			name:               "TrustsLocalhostInReverseProxyMode",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Forwarded-For",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.Header.Add("X-Forwarded-For", "127.0.0.1")
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: true,
 | |
| 		},
 | |
| 		// Check successful trusting of localhost in IPv6.
 | |
| 		{
 | |
| 			name:               "TrustsIP6LocalostInReverseProxyMode",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Forwarded-For",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.Header.Add("X-Forwarded-For", "::1")
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: true,
 | |
| 		},
 | |
| 		// Check does not trust random IPv4 address.
 | |
| 		{
 | |
| 			name:               "DoesNotTrustRandomIP4Address",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Forwarded-For",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.Header.Add("X-Forwarded-For", "12.34.56.78")
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 		// Check does not trust random IPv6 address.
 | |
| 		{
 | |
| 			name:               "DoesNotTrustRandomIP6Address",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Forwarded-For",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.Header.Add("X-Forwarded-For", "::2")
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 		// Check respects correct header.
 | |
| 		{
 | |
| 			name:               "RespectsCorrectHeaderInReverseProxyMode",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Forwarded-For",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.Header.Add("X-Real-IP", "::1")
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 		// Check doesn't trust if garbage is provided.
 | |
| 		{
 | |
| 			name:               "DoesNotTrustGarbageInReverseProxyMode",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       true,
 | |
| 			realClientIPHeader: "X-Forwarded-For",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.Header.Add("X-Forwarded-For", "adsfljk29242as!!")
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 		// Check doesn't trust if garbage is provided (no reverse-proxy).
 | |
| 		{
 | |
| 			name:               "DoesNotTrustGarbage",
 | |
| 			trustedIPs:         []string{"127.0.0.0/8", "::1"},
 | |
| 			reverseProxy:       false,
 | |
| 			realClientIPHeader: "X-Real-IP",
 | |
| 			req: func() *http.Request {
 | |
| 				req, _ := http.NewRequest("GET", "/", nil)
 | |
| 				req.RemoteAddr = "adsfljk29242as!!"
 | |
| 				return req
 | |
| 			}(),
 | |
| 			expectTrusted: false,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for _, tt := range tests {
 | |
| 		t.Run(tt.name, func(t *testing.T) {
 | |
| 			opts := baseTestOptions()
 | |
| 			opts.Upstreams = []string{"static://200"}
 | |
| 			opts.TrustedIPs = tt.trustedIPs
 | |
| 			opts.ReverseProxy = tt.reverseProxy
 | |
| 			opts.RealClientIPHeader = tt.realClientIPHeader
 | |
| 			validation.Validate(opts)
 | |
| 
 | |
| 			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | |
| 			assert.NoError(t, err)
 | |
| 			rw := httptest.NewRecorder()
 | |
| 
 | |
| 			proxy.ServeHTTP(rw, tt.req)
 | |
| 			if tt.expectTrusted {
 | |
| 				assert.Equal(t, 200, rw.Code)
 | |
| 			} else {
 | |
| 				assert.Equal(t, 403, rw.Code)
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 |