128 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			128 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Go
		
	
	
	
package middleware
 | 
						|
 | 
						|
import (
 | 
						|
	"net/http"
 | 
						|
	"net/http/httptest"
 | 
						|
 | 
						|
	"github.com/google/uuid"
 | 
						|
	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
 | 
						|
	. "github.com/onsi/ginkgo"
 | 
						|
	. "github.com/onsi/gomega"
 | 
						|
)
 | 
						|
 | 
						|
// Fixtures shared by the scope middleware specs.
const (
	// testRequestHeader is the header name handed to NewScope as the place
	// to look for a caller-supplied request ID.
	testRequestHeader = "X-Request-Id"

	// testRequestID is the value placed in testRequestHeader by the specs
	// that exercise the "header is present" path.
	testRequestID = "11111111-2222-4333-8444-555555555555"

	// mockRand io.Reader below counts bytes from 0-255 in order, so the
	// first 16 bytes it yields are 0x00..0x0f. Rendered as a random
	// (version 4) UUID, the high nibble of byte 6 becomes 4 and the top
	// bits of byte 8 become 10, which is why the literal reads "4607" and
	// "8809" rather than "0607" and "0809".
	testRandomUUID = "00010203-0405-4607-8809-0a0b0c0d0e0f"
)
 | 
						|
 | 
						|
var _ = Describe("Scope Suite", func() {
 | 
						|
	Context("NewScope", func() {
 | 
						|
		var request, nextRequest *http.Request
 | 
						|
		var rw http.ResponseWriter
 | 
						|
 | 
						|
		BeforeEach(func() {
 | 
						|
			var err error
 | 
						|
			request, err = http.NewRequest("", "http://127.0.0.1/", nil)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			rw = httptest.NewRecorder()
 | 
						|
		})
 | 
						|
 | 
						|
		Context("ReverseProxy is false", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				handler := NewScope(false, testRequestHeader)(
 | 
						|
					http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						|
						nextRequest = r
 | 
						|
						w.WriteHeader(200)
 | 
						|
					}))
 | 
						|
				handler.ServeHTTP(rw, request)
 | 
						|
			})
 | 
						|
 | 
						|
			It("does not add a scope to the original request", func() {
 | 
						|
				Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
 | 
						|
			})
 | 
						|
 | 
						|
			It("cannot load a scope from the original request using GetRequestScope", func() {
 | 
						|
				Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
 | 
						|
			})
 | 
						|
 | 
						|
			It("adds a scope to the request for the next handler", func() {
 | 
						|
				Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
 | 
						|
			})
 | 
						|
 | 
						|
			It("can load a scope from the next handler's request using GetRequestScope", func() {
 | 
						|
				scope := middlewareapi.GetRequestScope(nextRequest)
 | 
						|
				Expect(scope).ToNot(BeNil())
 | 
						|
				Expect(scope.ReverseProxy).To(BeFalse())
 | 
						|
			})
 | 
						|
		})
 | 
						|
 | 
						|
		Context("ReverseProxy is true", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				handler := NewScope(true, testRequestHeader)(
 | 
						|
					http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						|
						nextRequest = r
 | 
						|
						w.WriteHeader(200)
 | 
						|
					}))
 | 
						|
				handler.ServeHTTP(rw, request)
 | 
						|
			})
 | 
						|
 | 
						|
			It("return a scope where the ReverseProxy field is true", func() {
 | 
						|
				scope := middlewareapi.GetRequestScope(nextRequest)
 | 
						|
				Expect(scope).ToNot(BeNil())
 | 
						|
				Expect(scope.ReverseProxy).To(BeTrue())
 | 
						|
			})
 | 
						|
		})
 | 
						|
 | 
						|
		Context("Request ID header is present", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				request.Header.Add(testRequestHeader, testRequestID)
 | 
						|
				handler := NewScope(false, testRequestHeader)(
 | 
						|
					http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						|
						nextRequest = r
 | 
						|
						w.WriteHeader(200)
 | 
						|
					}))
 | 
						|
				handler.ServeHTTP(rw, request)
 | 
						|
			})
 | 
						|
 | 
						|
			It("sets the RequestID using the request", func() {
 | 
						|
				scope := middlewareapi.GetRequestScope(nextRequest)
 | 
						|
				Expect(scope.RequestID).To(Equal(testRequestID))
 | 
						|
			})
 | 
						|
		})
 | 
						|
 | 
						|
		Context("Request ID header is missing", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				uuid.SetRand(mockRand{})
 | 
						|
 | 
						|
				handler := NewScope(true, testRequestHeader)(
 | 
						|
					http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						|
						nextRequest = r
 | 
						|
						w.WriteHeader(200)
 | 
						|
					}))
 | 
						|
				handler.ServeHTTP(rw, request)
 | 
						|
			})
 | 
						|
 | 
						|
			AfterEach(func() {
 | 
						|
				uuid.SetRand(nil)
 | 
						|
			})
 | 
						|
 | 
						|
			It("sets the RequestID using a random UUID", func() {
 | 
						|
				scope := middlewareapi.GetRequestScope(nextRequest)
 | 
						|
				Expect(scope.RequestID).To(Equal(testRandomUUID))
 | 
						|
			})
 | 
						|
		})
 | 
						|
	})
 | 
						|
})
 | 
						|
 | 
						|
// mockRand is a stateless, deterministic stand-in for a random byte source.
// It has a Read method with the io.Reader signature, so it can be handed to
// anything that pulls bytes from a reader.
type mockRand struct{}

// Read fills p with the sequence 0, 1, 2, ... (wrapping back to 0 after 255)
// and reports that all len(p) bytes were written, never returning an error.
// The counter is the index into p, so every call produces the same sequence
// starting from 0 regardless of earlier calls.
func (mockRand) Read(p []byte) (int, error) {
	for i := range p {
		p[i] = byte(i % 256)
	}
	return len(p), nil
}
 |