From 537e596904fe6e1f01e86d97c1997129925a1f00 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 26 Jun 2021 11:48:49 +0100 Subject: [PATCH] Add claim extractor provider util --- go.mod | 1 + pkg/providers/util/claim_extractor.go | 210 ++++++++ pkg/providers/util/claim_extractor_test.go | 530 +++++++++++++++++++++ pkg/providers/util/util_suite_test.go | 17 + 4 files changed, 758 insertions(+) create mode 100644 pkg/providers/util/claim_extractor.go create mode 100644 pkg/providers/util/claim_extractor_test.go create mode 100644 pkg/providers/util/util_suite_test.go diff --git a/go.mod b/go.mod index ed1229d5..08a2e390 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/onsi/gomega v1.10.2 github.com/pierrec/lz4 v2.5.2+incompatible github.com/prometheus/client_golang v1.9.0 + github.com/spf13/cast v1.3.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.6.3 github.com/stretchr/testify v1.6.1 diff --git a/pkg/providers/util/claim_extractor.go b/pkg/providers/util/claim_extractor.go new file mode 100644 index 00000000..f0fe320e --- /dev/null +++ b/pkg/providers/util/claim_extractor.go @@ -0,0 +1,210 @@ +package util + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/bitly/go-simplejson" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "github.com/spf13/cast" +) + +// ClaimExtractor is used to extract claim values from an ID Token, or, if not +// present, from the profile URL. +type ClaimExtractor interface { + // GetClaim fetches a named claim and returns the value. + GetClaim(claim string) (interface{}, bool, error) + + // GetClaimInto fetches a named claim and puts the value into the destination. + GetClaimInto(claim string, dst interface{}) (bool, error) +} + +// NewClaimExtractor constructs a new ClaimExtractor from the raw ID Token. +// If needed, it will use the profile URL to look up a claim if it isn't present +// within the ID Token. +func NewClaimExtractor(ctx context.Context, idToken string, profileURL *url.URL, profileRequestHeaders http.Header) (ClaimExtractor, error) { + payload, err := parseJWT(idToken) + if err != nil { + return nil, fmt.Errorf("failed to parse ID Token: %v", err) + } + + tokenClaims, err := simplejson.NewJson(payload) + if err != nil { + return nil, fmt.Errorf("failed to parse ID Token payload: %v", err) + } + + return &claimExtractor{ + ctx: ctx, + profileURL: profileURL, + requestHeaders: profileRequestHeaders, + tokenClaims: tokenClaims, + }, nil +} + +// claimExtractor implements the ClaimExtractor interface +type claimExtractor struct { + profileURL *url.URL + ctx context.Context + requestHeaders map[string][]string + tokenClaims *simplejson.Json + profileClaims *simplejson.Json +} + +// GetClaim will return the value claim if it exists. +// It will only return an error if the profile URL needs to be fetched due to +// the claim not being present in the ID Token. +func (c *claimExtractor) GetClaim(claim string) (interface{}, bool, error) { + if claim == "" { + return nil, false, nil + } + + if value := getClaimFrom(claim, c.tokenClaims); value != nil { + return value, true, nil + } + + if c.profileClaims == nil { + profileClaims, err := c.loadProfileClaims() + if err != nil { + return nil, false, fmt.Errorf("failed to fetch claims from profile URL: %v", err) + } + + c.profileClaims = profileClaims + } + + if value := getClaimFrom(claim, c.profileClaims); value != nil { + return value, true, nil + } + + return nil, false, nil +} + +// loadProfileClaims will fetch the profileURL using the provided headers as +// authentication. +func (c *claimExtractor) loadProfileClaims() (*simplejson.Json, error) { + if c.profileURL == nil || c.requestHeaders == nil { + // When no profileURL is set, we return a non-empty map so that + // we don't attempt to populate the profile claims again. + // If there are no headers, the request would be unauthorized so we also skip + // in this case too. + return simplejson.New(), nil + } + + claims, err := requests.New(c.profileURL.String()). + WithContext(c.ctx). + WithHeaders(c.requestHeaders). + Do(). + UnmarshalJSON() + if err != nil { + return nil, fmt.Errorf("error making request to profile URL: %v", err) + } + + return claims, nil +} + +// GetClaimInto loads a claim and places it into the destination interface. +// This will attempt to coerce the claim into the specified type. +// If it cannot be coerced, an error may be returned. +func (c *claimExtractor) GetClaimInto(claim string, dst interface{}) (bool, error) { + value, exists, err := c.GetClaim(claim) + if err != nil { + return false, fmt.Errorf("could not get claim %q: %v", claim, err) + } + if !exists { + return false, nil + } + if err := coerceClaim(value, dst); err != nil { + return false, fmt.Errorf("could no coerce claim: %v", err) + } + + return true, nil +} + +// This has been copied from https://github.com/coreos/go-oidc/blob/8d771559cf6e5111c9b9159810d0e4538e7cdc82/verify.go#L120-L130 +// We use it to grab the raw ID Token payload so that we can parse it into the JSON library. +func parseJWT(p string) ([]byte, error) { + parts := strings.Split(p, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err) + } + return payload, nil +} + +// getClaimFrom gets a claim from a Json object. +// It can accept either a single claim name or a json path. +// Paths with indexes are not supported. +func getClaimFrom(claim string, src *simplejson.Json) interface{} { + claimParts := strings.Split(claim, ".") + return src.GetPath(claimParts...).Interface() +} + +// coerceClaim tries to convert the value into the destination interface type. +// If it can convert the value, it will then store the value in the destination +// interface. +func coerceClaim(value, dst interface{}) error { + switch d := dst.(type) { + case *string: + str, err := toString(value) + if err != nil { + return fmt.Errorf("could not convert value to string: %v", err) + } + *d = str + case *[]string: + strSlice, err := toStringSlice(value) + if err != nil { + return fmt.Errorf("could not convert value to string slice: %v", err) + } + *d = strSlice + case *bool: + *d = cast.ToBool(value) + default: + return fmt.Errorf("unknown type for destination: %T", dst) + } + return nil +} + +// toStringSlice converts an interface (either a slice or single value) into +// a slice of strings. +func toStringSlice(value interface{}) ([]string, error) { + var sliceValues []interface{} + switch v := value.(type) { + case []interface{}: + sliceValues = v + case interface{}: + sliceValues = []interface{}{v} + default: + sliceValues = cast.ToSlice(value) + } + + out := []string{} + for _, v := range sliceValues { + str, err := toString(v) + if err != nil { + return nil, fmt.Errorf("could not convert slice entry to string %v: %v", v, err) + } + out = append(out, str) + } + return out, nil +} + +// toString coerces a value into a string. +// If it is non-string, marshal it into JSON. +func toString(value interface{}) (string, error) { + if str, err := cast.ToStringE(value); err == nil { + return str, nil + } + + jsonStr, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(jsonStr), nil +} diff --git a/pkg/providers/util/claim_extractor_test.go b/pkg/providers/util/claim_extractor_test.go new file mode 100644 index 00000000..fb6220fe --- /dev/null +++ b/pkg/providers/util/claim_extractor_test.go @@ -0,0 +1,530 @@ +package util + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "sync/atomic" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +const ( + emptyJSON = "{}" + profilePath = "/userinfo" + authorizedAccessToken = "valid_access_token" + basicIDTokenPayload = `{ + "user": "idTokenUser", + "email": "idTokenEmail", + "groups": [ + "idTokenGroup1", + "idTokenGroup2" + ] + }` + basicProfileURLPayload = `{ + "user": "profileUser", + "email": "profileEmail", + "groups": [ + "profileGroup1", + "profileGroup2" + ] + }` + nestedClaimPayload = `{ + "auth": { + "user": { + "username": "nestedUser" + } + } + }` + complexGroupsPayload = `{ + "groups": [ + { + "groupID": "group1", + "roles": ["admin"] + }, + { + "groupID": "group2", + "roles": ["user", "employee"] + } + ] + }` +) + +var _ = Describe("Claim Extractor Suite", func() { + Context("Claim Extractor", func() { + type newClaimExtractorTableInput struct { + idToken string + expectedError error + } + + DescribeTable("NewClaimExtractor", + func(in newClaimExtractorTableInput) { + _, err := NewClaimExtractor(context.Background(), in.idToken, nil, nil) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + }, + Entry("with a valid JWT", newClaimExtractorTableInput{ + idToken: createJWTFromPayload(basicIDTokenPayload), + expectedError: nil, + }), + Entry("with a JWT with a non-json payload", newClaimExtractorTableInput{ + idToken: createJWTFromPayload("this is not JSON"), + expectedError: errors.New("failed to parse ID Token payload: invalid character 'h' in literal true (expecting 'r')"), + }), + Entry("with an IDToken with the wrong number of parts", newClaimExtractorTableInput{ + idToken: "eyJeyJ", + expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt, expected 3 parts got 1"), + }), + Entry("with an non-base64 IDToken", newClaimExtractorTableInput{ + idToken: "{metadata}.{payload}.{signature}", + expectedError: errors.New("failed to parse ID Token: oidc: malformed jwt payload: illegal base64 data at input byte 0"), + }), + ) + + type getClaimTableInput struct { + testClaimExtractorOpts + claim string + expectedValue interface{} + expectExists bool + expectedError error + } + + DescribeTable("GetClaim", + func(in getClaimTableInput) { + claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + value, exists, err := claimExtractor.GetClaim(in.claim) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + return + } + + Expect(err).ToNot(HaveOccurred()) + if in.expectedValue != nil { + Expect(value).To(Equal(in.expectedValue)) + } else { + Expect(value).To(BeNil()) + } + + Expect(exists).To(Equal(in.expectExists)) + }, + Entry("retrieves a string claim from ID Token when present", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + expectExists: true, + expectedValue: "idTokenUser", + expectedError: nil, + }), + Entry("retrieves a slice claim from ID Token when present", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + expectExists: true, + expectedValue: []interface{}{"idTokenGroup1", "idTokenGroup2"}, + expectedError: nil, + }), + Entry("when the requested claim is the empty string", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + }, + claim: "", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the requested claim is the not found (with no profile URL)", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + profileRequestHeaders: newAuthorizedHeader(), + }, + claim: "not_found", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the requested claim is the not found (with profile URL)", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "not_found", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the requested claim is the not found (with no profile Headers)", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: nil, + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "not_found", + expectExists: false, + expectedValue: nil, + expectedError: nil, + }), + Entry("when the profile URL is unauthorized", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: make(http.Header), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "user", + expectExists: false, + expectedValue: nil, + expectedError: errors.New("failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"), + }), + Entry("retrieves a string claim from profile URL when not present in the ID Token", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "user", + expectExists: true, + expectedValue: "profileUser", + expectedError: nil, + }), + Entry("retrieves a string claim from a nested path", getClaimTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: nestedClaimPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "auth.user.username", + expectExists: true, + expectedValue: "nestedUser", + expectedError: nil, + }), + ) + }) + + It("GetClaim should only call the profile URL once", func() { + var counter int32 + countRequestsHandler := func(rw http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&counter, 1) + rw.Write([]byte(basicProfileURLPayload)) + } + + claimExtractor, serverClose, err := newTestClaimExtractor(testClaimExtractorOpts{ + idTokenPayload: "{}", + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: countRequestsHandler, + }) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + value, exists, err := claimExtractor.GetClaim("user") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal("profileUser")) + Expect(counter).To(BeEquivalentTo(1)) + + // Check a different claim, but expect the count not to increase + value, exists, err = claimExtractor.GetClaim("email") + Expect(err).ToNot(HaveOccurred()) + Expect(exists).To(BeTrue()) + Expect(value).To(Equal("profileEmail")) + Expect(counter).To(BeEquivalentTo(1)) + }) + + type getClaimIntoTableInput struct { + testClaimExtractorOpts + into interface{} + claim string + expectedValue interface{} + expectExists bool + expectedError error + } + + DescribeTable("GetClaimInto", + func(in getClaimIntoTableInput) { + claimExtractor, serverClose, err := newTestClaimExtractor(in.testClaimExtractorOpts) + Expect(err).ToNot(HaveOccurred()) + if serverClose != nil { + defer serverClose() + } + + exists, err := claimExtractor.GetClaimInto(in.claim, in.into) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + return + } + + Expect(err).ToNot(HaveOccurred()) + if in.expectedValue != nil { + Expect(in.into).To(Equal(in.expectedValue)) + } else { + Expect(in.into).To(BeEmpty()) + } + + Expect(exists).To(Equal(in.expectExists)) + }, + Entry("retrieves a string claim from ID Token when present into a string", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + into: stringPointer(""), + expectExists: true, + expectedValue: stringPointer("idTokenUser"), + expectedError: nil, + }), + Entry("retrieves a string claim from ID Token when present into a string slice", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + into: stringSlicePointer([]string{}), + expectExists: true, + expectedValue: stringSlicePointer([]string{"idTokenUser"}), + expectedError: nil, + }), + Entry("retrieves a string slice claim from ID Token when present into a string slice", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + into: stringSlicePointer([]string{}), + expectExists: true, + expectedValue: stringSlicePointer([]string{"idTokenGroup1", "idTokenGroup2"}), + expectedError: nil, + }), + Entry("retrieves a string slice claim from ID Token when present into a string", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + into: stringPointer(""), + expectExists: true, + expectedValue: stringPointer("[\"idTokenGroup1\",\"idTokenGroup2\"]"), + expectedError: nil, + }), + Entry("returns an error when a non-pointer is passed for the destination", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "user", + into: "", + expectExists: false, + expectedValue: "", + expectedError: errors.New("could no coerce claim: unknown type for destination: string"), + }), + Entry("flattens a complex claim value into a JSON string", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: complexGroupsPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: shouldNotBeRequestedProfileHandler, + }, + claim: "groups", + into: stringSlicePointer([]string{}), + expectExists: true, + expectedValue: stringSlicePointer([]string{ + "{\"groupID\":\"group1\",\"roles\":[\"admin\"]}", + "{\"groupID\":\"group2\",\"roles\":[\"user\",\"employee\"]}", + }), + expectedError: nil, + }), + Entry("does not return an error when the claim does not exist", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: basicIDTokenPayload, + setProfileURL: true, + profileRequestHeaders: newAuthorizedHeader(), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "not_found", + into: stringPointer(""), + expectExists: false, + expectedValue: stringPointer(""), + expectedError: nil, + }), + Entry("returns an error when the profile request is unauthorized", getClaimIntoTableInput{ + testClaimExtractorOpts: testClaimExtractorOpts{ + idTokenPayload: emptyJSON, + setProfileURL: true, + profileRequestHeaders: make(http.Header), + profileRequestHandler: requiresAuthProfileHandler, + }, + claim: "user", + into: stringPointer(""), + expectExists: false, + expectedValue: stringPointer(""), + expectedError: errors.New("could not get claim \"user\": failed to fetch claims from profile URL: error making request to profile URL: unexpected status \"403\": Unauthorized"), + }), + ) + + type coerceClaimTableInput struct { + value interface{} + dst interface{} + expectedDst interface{} + expectedError error + } + + DescribeTable("coerceClaim", + func(in coerceClaimTableInput) { + err := coerceClaim(in.value, in.dst) + if in.expectedError != nil { + Expect(err).To(MatchError(in.expectedError)) + return + } + + Expect(err).ToNot(HaveOccurred()) + Expect(in.dst).To(Equal(in.expectedDst)) + }, + Entry("coerces a string to a string", coerceClaimTableInput{ + value: "some_string", + dst: stringPointer(""), + expectedDst: stringPointer("some_string"), + }), + Entry("coerces a slice to a string slice", coerceClaimTableInput{ + value: []interface{}{"a", "b"}, + dst: stringSlicePointer([]string{}), + expectedDst: stringSlicePointer([]string{"a", "b"}), + }), + Entry("coerces a bool to a bool", coerceClaimTableInput{ + value: true, + dst: boolPointer(false), + expectedDst: boolPointer(true), + }), + Entry("coerces a string to a bool", coerceClaimTableInput{ + value: "true", + dst: boolPointer(false), + expectedDst: boolPointer(true), + }), + Entry("coerces a map to a string", coerceClaimTableInput{ + value: map[string]interface{}{ + "foo": []interface{}{"bar", "baz"}, + }, + dst: stringPointer(""), + expectedDst: stringPointer("{\"foo\":[\"bar\",\"baz\"]}"), + }), + ) +}) + +// ****************************************** +// Helpers for setting up the claim extractor +// ****************************************** + +type testClaimExtractorOpts struct { + idTokenPayload string + setProfileURL bool + profileRequestHeaders http.Header + profileRequestHandler http.HandlerFunc +} + +func newTestClaimExtractor(in testClaimExtractorOpts) (ClaimExtractor, func(), error) { + var profileURL *url.URL + var closeServer func() + if in.setProfileURL { + server := httptest.NewServer(http.HandlerFunc(in.profileRequestHandler)) + closeServer = server.Close + + var err error + profileURL, err = url.Parse("http://" + server.Listener.Addr().String() + profilePath) + Expect(err).ToNot(HaveOccurred()) + } + + rawIDToken := createJWTFromPayload(in.idTokenPayload) + + claimExtractor, err := NewClaimExtractor(context.Background(), rawIDToken, profileURL, in.profileRequestHeaders) + return claimExtractor, closeServer, err +} + +func createJWTFromPayload(payload string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(emptyJSON)) + payloadJSON := base64.RawURLEncoding.EncodeToString([]byte(payload)) + + return fmt.Sprintf("%s.%s.%s", header, payloadJSON, header) +} + +func newAuthorizedHeader() http.Header { + headers := make(http.Header) + headers.Add("Authorization", "Bearer "+authorizedAccessToken) + return headers +} + +func hasAuthorizedHeader(headers http.Header) bool { + return headers.Get("Authorization") == "Bearer "+authorizedAccessToken +} + +// *********************** +// Typed Pointer Functions +// *********************** + +func stringPointer(in string) *string { + return &in +} + +func stringSlicePointer(in []string) *[]string { + return &in +} + +func boolPointer(in bool) *bool { + return &in +} + +// ****************************** +// Different profile URL handlers +// ****************************** + +func shouldNotBeRequestedProfileHandler(_ http.ResponseWriter, _ *http.Request) { + defer GinkgoRecover() + Expect(true).To(BeFalse(), "Unexpected request to profile URL") +} + +func requiresAuthProfileHandler(rw http.ResponseWriter, req *http.Request) { + if !hasAuthorizedHeader(req.Header) { + rw.WriteHeader(403) + rw.Write([]byte("Unauthorized")) + return + } + + rw.Write([]byte(basicProfileURLPayload)) +} diff --git a/pkg/providers/util/util_suite_test.go b/pkg/providers/util/util_suite_test.go new file mode 100644 index 00000000..c76d9813 --- /dev/null +++ b/pkg/providers/util/util_suite_test.go @@ -0,0 +1,17 @@ +package util + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestProviderUtilSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + logger.SetErrOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Provider Utils") +}