Merge pull request #660 from oauth2-proxy/request-builder
Use builder pattern to simplify requests to external endpoints
This commit is contained in:
		
						commit
						d29766609b
					
				|  | @ -8,6 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v6.0.0 | ## Changes since v6.0.0 | ||||||
| 
 | 
 | ||||||
|  | - [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed) | ||||||
| - [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed) | - [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed) | ||||||
| - [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed) | - [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed) | ||||||
| - [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves) | - [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,118 @@ | ||||||
|  | package requests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Builder allows users to construct a request and then execute the
 | ||||||
|  | // request via Do().
 | ||||||
|  | // Do returns a Result which allows the user to get the body,
 | ||||||
|  | // unmarshal the body into an interface, or into a simplejson.Json.
 | ||||||
|  | type Builder interface { | ||||||
|  | 	WithContext(context.Context) Builder | ||||||
|  | 	WithBody(io.Reader) Builder | ||||||
|  | 	WithMethod(string) Builder | ||||||
|  | 	WithHeaders(http.Header) Builder | ||||||
|  | 	SetHeader(key, value string) Builder | ||||||
|  | 	Do() Result | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type builder struct { | ||||||
|  | 	context  context.Context | ||||||
|  | 	method   string | ||||||
|  | 	endpoint string | ||||||
|  | 	body     io.Reader | ||||||
|  | 	header   http.Header | ||||||
|  | 	result   *result | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // New provides a new Builder for the given endpoint.
 | ||||||
|  | func New(endpoint string) Builder { | ||||||
|  | 	return &builder{ | ||||||
|  | 		endpoint: endpoint, | ||||||
|  | 		method:   "GET", | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WithContext adds a context to the request.
 | ||||||
|  | // If no context is provided, context.Background() is used instead.
 | ||||||
|  | func (r *builder) WithContext(ctx context.Context) Builder { | ||||||
|  | 	r.context = ctx | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WithBody adds a body to the request.
 | ||||||
|  | func (r *builder) WithBody(body io.Reader) Builder { | ||||||
|  | 	r.body = body | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WithMethod sets the request method. Defaults to "GET".
 | ||||||
|  | func (r *builder) WithMethod(method string) Builder { | ||||||
|  | 	r.method = method | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // WithHeaders replaces the request header map with the given header map.
 | ||||||
|  | func (r *builder) WithHeaders(header http.Header) Builder { | ||||||
|  | 	r.header = header | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetHeader sets a single header to the given value.
 | ||||||
|  | // May be used to add multiple headers.
 | ||||||
|  | func (r *builder) SetHeader(key, value string) Builder { | ||||||
|  | 	if r.header == nil { | ||||||
|  | 		r.header = make(http.Header) | ||||||
|  | 	} | ||||||
|  | 	r.header.Set(key, value) | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Do performs the request and returns the response in its raw form.
 | ||||||
|  | // If the request has already been performed, returns the previous result.
 | ||||||
|  | // This will not allow you to repeat a request.
 | ||||||
|  | func (r *builder) Do() Result { | ||||||
|  | 	if r.result != nil { | ||||||
|  | 		// Request has already been done
 | ||||||
|  | 		return r.result | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Must provide a non-nil context to NewRequestWithContext
 | ||||||
|  | 	if r.context == nil { | ||||||
|  | 		r.context = context.Background() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return r.do() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // do creates the request, executes it with the default client and extracts the
 | ||||||
|  | // the body into the response
 | ||||||
|  | func (r *builder) do() Result { | ||||||
|  | 	req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.result = &result{err: fmt.Errorf("error creating request: %v", err)} | ||||||
|  | 		return r.result | ||||||
|  | 	} | ||||||
|  | 	req.Header = r.header | ||||||
|  | 
 | ||||||
|  | 	resp, err := http.DefaultClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.result = &result{err: fmt.Errorf("error performing request: %v", err)} | ||||||
|  | 		return r.result | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	defer resp.Body.Close() | ||||||
|  | 	body, err := ioutil.ReadAll(resp.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.result = &result{err: fmt.Errorf("error reading response body: %v", err)} | ||||||
|  | 		return r.result | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	r.result = &result{response: resp, body: body} | ||||||
|  | 	return r.result | ||||||
|  | } | ||||||
|  | @ -0,0 +1,376 @@ | ||||||
|  | package requests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"context" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/bitly/go-simplejson" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Builder suite", func() { | ||||||
|  | 	var b Builder | ||||||
|  | 	getBuilder := func() Builder { return b } | ||||||
|  | 
 | ||||||
|  | 	baseHeaders := http.Header{ | ||||||
|  | 		"Accept-Encoding": []string{"gzip"}, | ||||||
|  | 		"User-Agent":      []string{"Go-http-client/1.1"}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		// Most tests will request the server address
 | ||||||
|  | 		b = New(serverAddr + "/json/path") | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with a basic request", func() { | ||||||
|  | 		assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 			Method:     "GET", | ||||||
|  | 			Header:     baseHeaders, | ||||||
|  | 			Body:       []byte{}, | ||||||
|  | 			RequestURI: "/json/path", | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with a context", func() { | ||||||
|  | 		var ctx context.Context | ||||||
|  | 		var cancel context.CancelFunc | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			ctx, cancel = context.WithCancel(context.Background()) | ||||||
|  | 			b = b.WithContext(ctx) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		AfterEach(func() { | ||||||
|  | 			cancel() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 			Method:     "GET", | ||||||
|  | 			Header:     baseHeaders, | ||||||
|  | 			Body:       []byte{}, | ||||||
|  | 			RequestURI: "/json/path", | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("if the context is cancelled", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				cancel() | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertRequestError(getBuilder, "context canceled") | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with a body", func() { | ||||||
|  | 		const body = "{\"some\": \"body\"}" | ||||||
|  | 		header := baseHeaders.Clone() | ||||||
|  | 		header.Set("Content-Length", fmt.Sprintf("%d", len(body))) | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			buf := bytes.NewBuffer([]byte(body)) | ||||||
|  | 			b = b.WithBody(buf) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 			Method:     "GET", | ||||||
|  | 			Header:     header, | ||||||
|  | 			Body:       []byte(body), | ||||||
|  | 			RequestURI: "/json/path", | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with a method", func() { | ||||||
|  | 		Context("POST with a body", func() { | ||||||
|  | 			const body = "{\"some\": \"body\"}" | ||||||
|  | 			header := baseHeaders.Clone() | ||||||
|  | 			header.Set("Content-Length", fmt.Sprintf("%d", len(body))) | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				buf := bytes.NewBuffer([]byte(body)) | ||||||
|  | 				b = b.WithMethod("POST").WithBody(buf) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 				Method:     "POST", | ||||||
|  | 				Header:     header, | ||||||
|  | 				Body:       []byte(body), | ||||||
|  | 				RequestURI: "/json/path", | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("POST without a body", func() { | ||||||
|  | 			header := baseHeaders.Clone() | ||||||
|  | 			header.Set("Content-Length", "0") | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				b = b.WithMethod("POST") | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 				Method:     "POST", | ||||||
|  | 				Header:     header, | ||||||
|  | 				Body:       []byte{}, | ||||||
|  | 				RequestURI: "/json/path", | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("OPTIONS", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				b = b.WithMethod("OPTIONS") | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 				Method:     "OPTIONS", | ||||||
|  | 				Header:     baseHeaders, | ||||||
|  | 				Body:       []byte{}, | ||||||
|  | 				RequestURI: "/json/path", | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("INVALID-\\t-METHOD", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				b = b.WithMethod("INVALID-\t-METHOD") | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"") | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with headers", func() { | ||||||
|  | 		Context("setting a header", func() { | ||||||
|  | 			header := baseHeaders.Clone() | ||||||
|  | 			header.Set("header", "value") | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				b = b.SetHeader("header", "value") | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 				Method:     "GET", | ||||||
|  | 				Header:     header, | ||||||
|  | 				Body:       []byte{}, | ||||||
|  | 				RequestURI: "/json/path", | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("then replacing the headers", func() { | ||||||
|  | 				replacementHeaders := http.Header{ | ||||||
|  | 					"Accept-Encoding": []string{"*"}, | ||||||
|  | 					"User-Agent":      []string{"test-agent"}, | ||||||
|  | 					"Foo":             []string{"bar, baz"}, | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					b = b.WithHeaders(replacementHeaders) | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 					Method:     "GET", | ||||||
|  | 					Header:     replacementHeaders, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					RequestURI: "/json/path", | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("replacing the header", func() { | ||||||
|  | 			replacementHeaders := http.Header{ | ||||||
|  | 				"Accept-Encoding": []string{"*"}, | ||||||
|  | 				"User-Agent":      []string{"test-agent"}, | ||||||
|  | 				"Foo":             []string{"bar, baz"}, | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				b = b.WithHeaders(replacementHeaders) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 				Method:     "GET", | ||||||
|  | 				Header:     replacementHeaders, | ||||||
|  | 				Body:       []byte{}, | ||||||
|  | 				RequestURI: "/json/path", | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			Context("then setting a header", func() { | ||||||
|  | 				header := replacementHeaders.Clone() | ||||||
|  | 				header.Set("User-Agent", "different-agent") | ||||||
|  | 
 | ||||||
|  | 				BeforeEach(func() { | ||||||
|  | 					b = b.SetHeader("User-Agent", "different-agent") | ||||||
|  | 				}) | ||||||
|  | 
 | ||||||
|  | 				assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 					Method:     "GET", | ||||||
|  | 					Header:     header, | ||||||
|  | 					Body:       []byte{}, | ||||||
|  | 					RequestURI: "/json/path", | ||||||
|  | 				}) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("if the request has been completed and then modified", func() { | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			result := b.Do() | ||||||
|  | 			Expect(result.Error()).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			b.WithMethod("POST") | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("should not redo the request", func() { | ||||||
|  | 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||||
|  | 				Method:     "GET", | ||||||
|  | 				Header:     baseHeaders, | ||||||
|  | 				Body:       []byte{}, | ||||||
|  | 				RequestURI: "/json/path", | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("when the requested page is not found", func() { | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			b = New(serverAddr + "/not-found") | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		assertJSONError(getBuilder, "404 page not found") | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("when the requested page is not valid JSON", func() { | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			b = New(serverAddr + "/string/path") | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value") | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) { | ||||||
|  | 	Context("Do", func() { | ||||||
|  | 		var result Result | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			result = builder().Do() | ||||||
|  | 			Expect(result.Error()).ToNot(HaveOccurred()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("returns a successful status", func() { | ||||||
|  | 			Expect(result.StatusCode()).To(Equal(http.StatusOK)) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("made the expected request", func() { | ||||||
|  | 			actualRequest := testHTTPRequest{} | ||||||
|  | 			Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed()) | ||||||
|  | 
 | ||||||
|  | 			Expect(actualRequest).To(Equal(expectedRequest)) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalInto", func() { | ||||||
|  | 		var actualRequest testHTTPRequest | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("made the expected request", func() { | ||||||
|  | 			Expect(actualRequest).To(Equal(expectedRequest)) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalJSON", func() { | ||||||
|  | 		var response *simplejson.Json | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			var err error | ||||||
|  | 			response, err = builder().Do().UnmarshalJSON() | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("made the expected reqest", func() { | ||||||
|  | 			header := http.Header{} | ||||||
|  | 			for key, value := range response.Get("Header").MustMap() { | ||||||
|  | 				vs, ok := value.([]interface{}) | ||||||
|  | 				Expect(ok).To(BeTrue()) | ||||||
|  | 				svs := []string{} | ||||||
|  | 				for _, v := range vs { | ||||||
|  | 					sv, ok := v.(string) | ||||||
|  | 					Expect(ok).To(BeTrue()) | ||||||
|  | 					svs = append(svs, sv) | ||||||
|  | 				} | ||||||
|  | 				header[key] = svs | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			// Other json unmarhsallers base64 decode byte slices automatically
 | ||||||
|  | 			body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString()) | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			actualRequest := testHTTPRequest{ | ||||||
|  | 				Method:     response.Get("Method").MustString(), | ||||||
|  | 				Header:     header, | ||||||
|  | 				Body:       body, | ||||||
|  | 				RequestURI: response.Get("RequestURI").MustString(), | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			Expect(actualRequest).To(Equal(expectedRequest)) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func assertRequestError(builder func() Builder, errorMessage string) { | ||||||
|  | 	Context("Do", func() { | ||||||
|  | 		It("returns an error", func() { | ||||||
|  | 			result := builder().Do() | ||||||
|  | 			Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage))) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalInto", func() { | ||||||
|  | 		It("returns an error", func() { | ||||||
|  | 			var actualRequest testHTTPRequest | ||||||
|  | 			err := builder().Do().UnmarshalInto(&actualRequest) | ||||||
|  | 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||||
|  | 
 | ||||||
|  | 			// Should be empty
 | ||||||
|  | 			Expect(actualRequest).To(Equal(testHTTPRequest{})) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalJSON", func() { | ||||||
|  | 		It("returns an error", func() { | ||||||
|  | 			resp, err := builder().Do().UnmarshalJSON() | ||||||
|  | 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||||
|  | 			Expect(resp).To(BeNil()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func assertJSONError(builder func() Builder, errorMessage string) { | ||||||
|  | 	Context("Do", func() { | ||||||
|  | 		It("does not return an error", func() { | ||||||
|  | 			result := builder().Do() | ||||||
|  | 			Expect(result.Error()).To(BeNil()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalInto", func() { | ||||||
|  | 		It("returns an error", func() { | ||||||
|  | 			var actualRequest testHTTPRequest | ||||||
|  | 			err := builder().Do().UnmarshalInto(&actualRequest) | ||||||
|  | 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||||
|  | 
 | ||||||
|  | 			// Should be empty
 | ||||||
|  | 			Expect(actualRequest).To(Equal(testHTTPRequest{})) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalJSON", func() { | ||||||
|  | 		It("returns an error", func() { | ||||||
|  | 			resp, err := builder().Do().UnmarshalJSON() | ||||||
|  | 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||||
|  | 			Expect(resp).To(BeNil()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  | @ -1,74 +0,0 @@ | ||||||
| package requests |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" |  | ||||||
| 
 |  | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // Request parses the request body into a simplejson.Json object
 |  | ||||||
| func Request(req *http.Request) (*simplejson.Json, error) { |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Printf("%s %s %s", req.Method, req.URL, err) |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	if body != nil { |  | ||||||
| 		defer resp.Body.Close() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) |  | ||||||
| 
 |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("problem reading http request body: %w", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	data, err := simplejson.NewJson(body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("error unmarshalling json: %w", err) |  | ||||||
| 	} |  | ||||||
| 	return data, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // RequestJSON parses the request body into the given interface
 |  | ||||||
| func RequestJSON(req *http.Request, v interface{}) error { |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Printf("%s %s %s", req.Method, req.URL, err) |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	if body != nil { |  | ||||||
| 		defer resp.Body.Close() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("error reading body from http response: %w", err) |  | ||||||
| 	} |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return fmt.Errorf("got %d %s", resp.StatusCode, body) |  | ||||||
| 	} |  | ||||||
| 	return json.Unmarshal(body, v) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // RequestUnparsedResponse performs a GET and returns the raw response object
 |  | ||||||
| func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) { |  | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", url, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("error performing get request: %w", err) |  | ||||||
| 	} |  | ||||||
| 	req.Header = header |  | ||||||
| 
 |  | ||||||
| 	return http.DefaultClient.Do(req) |  | ||||||
| } |  | ||||||
|  | @ -0,0 +1,96 @@ | ||||||
|  | package requests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"log" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	server     *httptest.Server | ||||||
|  | 	serverAddr string | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestRequetsSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 	log.SetOutput(GinkgoWriter) | ||||||
|  | 
 | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "Requests Suite") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var _ = BeforeSuite(func() { | ||||||
|  | 	// Set up a webserver that reflects requests
 | ||||||
|  | 	mux := http.NewServeMux() | ||||||
|  | 	mux.Handle("/json/", &testHTTPUpstream{}) | ||||||
|  | 	mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) { | ||||||
|  | 		rw.Write([]byte("OK")) | ||||||
|  | 	}) | ||||||
|  | 	server = httptest.NewServer(mux) | ||||||
|  | 	serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | var _ = AfterSuite(func() { | ||||||
|  | 	server.Close() | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | // testHTTPRequest is a struct used to capture the state of a request made to
 | ||||||
|  | // the test server
 | ||||||
|  | type testHTTPRequest struct { | ||||||
|  | 	Method     string | ||||||
|  | 	Header     http.Header | ||||||
|  | 	Body       []byte | ||||||
|  | 	RequestURI string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type testHTTPUpstream struct{} | ||||||
|  | 
 | ||||||
|  | func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 	request, err := toTestHTTPRequest(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.writeError(rw, err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	data, err := json.Marshal(request) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.writeError(rw, err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rw.Header().Set("Content-Type", "application/json") | ||||||
|  | 	rw.Write(data) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { | ||||||
|  | 	rw.WriteHeader(500) | ||||||
|  | 	if err != nil { | ||||||
|  | 		rw.Write([]byte(err.Error())) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { | ||||||
|  | 	requestBody := []byte{} | ||||||
|  | 	if req.Body != http.NoBody { | ||||||
|  | 		var err error | ||||||
|  | 		requestBody, err = ioutil.ReadAll(req.Body) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return testHTTPRequest{}, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return testHTTPRequest{ | ||||||
|  | 		Method:     req.Method, | ||||||
|  | 		Header:     req.Header, | ||||||
|  | 		Body:       requestBody, | ||||||
|  | 		RequestURI: req.RequestURI, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | @ -1,136 +0,0 @@ | ||||||
| package requests |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/http/httptest" |  | ||||||
| 	"strings" |  | ||||||
| 	"testing" |  | ||||||
| 
 |  | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| 	"github.com/stretchr/testify/require" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server { |  | ||||||
| 	return httptest.NewServer(http.HandlerFunc( |  | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 			w.WriteHeader(responseCode) |  | ||||||
| 			_, err := w.Write([]byte(payload)) |  | ||||||
| 			require.NoError(t, err) |  | ||||||
| 		})) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRequest(t *testing.T) { |  | ||||||
| 	backend := testBackend(t, 200, "{\"foo\": \"bar\"}") |  | ||||||
| 	defer backend.Close() |  | ||||||
| 
 |  | ||||||
| 	req, _ := http.NewRequest("GET", backend.URL, nil) |  | ||||||
| 	response, err := Request(req) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	result, err := response.Get("foo").String() |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	assert.Equal(t, "bar", result) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRequestFailure(t *testing.T) { |  | ||||||
| 	// Create a backend to generate a test URL, then close it to cause a
 |  | ||||||
| 	// connection error.
 |  | ||||||
| 	backend := testBackend(t, 200, "{\"foo\": \"bar\"}") |  | ||||||
| 	backend.Close() |  | ||||||
| 
 |  | ||||||
| 	req, err := http.NewRequest("GET", backend.URL, nil) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	resp, err := Request(req) |  | ||||||
| 	assert.Equal(t, (*simplejson.Json)(nil), resp) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| 	if !strings.Contains(err.Error(), "refused") { |  | ||||||
| 		t.Error("expected error when a connection fails: ", err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestHttpErrorCode(t *testing.T) { |  | ||||||
| 	backend := testBackend(t, 404, "{\"foo\": \"bar\"}") |  | ||||||
| 	defer backend.Close() |  | ||||||
| 
 |  | ||||||
| 	req, err := http.NewRequest("GET", backend.URL, nil) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	resp, err := Request(req) |  | ||||||
| 	assert.Equal(t, (*simplejson.Json)(nil), resp) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestJsonParsingError(t *testing.T) { |  | ||||||
| 	backend := testBackend(t, 200, "not well-formed JSON") |  | ||||||
| 	defer backend.Close() |  | ||||||
| 
 |  | ||||||
| 	req, err := http.NewRequest("GET", backend.URL, nil) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	resp, err := Request(req) |  | ||||||
| 	assert.Equal(t, (*simplejson.Json)(nil), resp) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Parsing a URL practically never fails, so we won't cover that test case.
 |  | ||||||
| func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { |  | ||||||
| 	backend := httptest.NewServer(http.HandlerFunc( |  | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 			token := r.FormValue("access_token") |  | ||||||
| 			if r.URL.Path == "/" && token == "my_token" { |  | ||||||
| 				w.WriteHeader(200) |  | ||||||
| 				_, err := w.Write([]byte("some payload")) |  | ||||||
| 				require.NoError(t, err) |  | ||||||
| 			} else { |  | ||||||
| 				w.WriteHeader(403) |  | ||||||
| 			} |  | ||||||
| 		})) |  | ||||||
| 	defer backend.Close() |  | ||||||
| 
 |  | ||||||
| 	response, err := RequestUnparsedResponse( |  | ||||||
| 		context.Background(), backend.URL+"?access_token=my_token", nil) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	defer response.Body.Close() |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, 200, response.StatusCode) |  | ||||||
| 	body, err := ioutil.ReadAll(response.Body) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	assert.Equal(t, "some payload", string(body)) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { |  | ||||||
| 	backend := testBackend(t, 200, "some payload") |  | ||||||
| 	// Close the backend now to force a request failure.
 |  | ||||||
| 	backend.Close() |  | ||||||
| 
 |  | ||||||
| 	response, err := RequestUnparsedResponse( |  | ||||||
| 		context.Background(), backend.URL+"?access_token=my_token", nil) |  | ||||||
| 	assert.NotEqual(t, nil, err) |  | ||||||
| 	assert.Equal(t, (*http.Response)(nil), response) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { |  | ||||||
| 	backend := httptest.NewServer(http.HandlerFunc( |  | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { |  | ||||||
| 			if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { |  | ||||||
| 				w.WriteHeader(200) |  | ||||||
| 				_, err := w.Write([]byte("some payload")) |  | ||||||
| 				require.NoError(t, err) |  | ||||||
| 			} else { |  | ||||||
| 				w.WriteHeader(403) |  | ||||||
| 			} |  | ||||||
| 		})) |  | ||||||
| 	defer backend.Close() |  | ||||||
| 
 |  | ||||||
| 	headers := make(http.Header) |  | ||||||
| 	headers.Set("Auth", "my_token") |  | ||||||
| 	response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	defer response.Body.Close() |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, 200, response.StatusCode) |  | ||||||
| 	body, err := ioutil.ReadAll(response.Body) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 
 |  | ||||||
| 	assert.Equal(t, "some payload", string(body)) |  | ||||||
| } |  | ||||||
|  | @ -0,0 +1,98 @@ | ||||||
|  | package requests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/bitly/go-simplejson" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Result is the result of a request created by a Builder
 | ||||||
|  | type Result interface { | ||||||
|  | 	Error() error | ||||||
|  | 	StatusCode() int | ||||||
|  | 	Headers() http.Header | ||||||
|  | 	Body() []byte | ||||||
|  | 	UnmarshalInto(interface{}) error | ||||||
|  | 	UnmarshalJSON() (*simplejson.Json, error) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type result struct { | ||||||
|  | 	err      error | ||||||
|  | 	response *http.Response | ||||||
|  | 	body     []byte | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Error returns an error from the result if present
 | ||||||
|  | func (r *result) Error() error { | ||||||
|  | 	return r.err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // StatusCode returns the response's status code
 | ||||||
|  | func (r *result) StatusCode() int { | ||||||
|  | 	if r.response != nil { | ||||||
|  | 		return r.response.StatusCode | ||||||
|  | 	} | ||||||
|  | 	return 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Headers returns the response's headers
 | ||||||
|  | func (r *result) Headers() http.Header { | ||||||
|  | 	if r.response != nil { | ||||||
|  | 		return r.response.Header | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Body returns the response's body
 | ||||||
|  | func (r *result) Body() []byte { | ||||||
|  | 	return r.body | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // UnmarshalInto attempts to unmarshal the response into the the given interface.
 | ||||||
|  | // The response body is assumed to be JSON.
 | ||||||
|  | // The response must have a 200 status otherwise an error will be returned.
 | ||||||
|  | func (r *result) UnmarshalInto(into interface{}) error { | ||||||
|  | 	body, err := r.getBodyForUnmarshal() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := json.Unmarshal(body, into); err != nil { | ||||||
|  | 		return fmt.Errorf("error unmarshalling body: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // UnmarshalJSON performs the request and attempts to unmarshal the response into a
 | ||||||
|  | // simplejson.Json. The response body is assume to be JSON.
 | ||||||
|  | // The response must have a 200 status otherwise an error will be returned.
 | ||||||
|  | func (r *result) UnmarshalJSON() (*simplejson.Json, error) { | ||||||
|  | 	body, err := r.getBodyForUnmarshal() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	data, err := simplejson.NewJson(body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error reading json: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return data, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getBodyForUnmarshal returns the body if there wasn't an error and the status
 | ||||||
|  | // code was 200.
 | ||||||
|  | func (r *result) getBodyForUnmarshal() ([]byte, error) { | ||||||
|  | 	if r.Error() != nil { | ||||||
|  | 		return nil, r.Error() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Only unmarshal body if the response was successful
 | ||||||
|  | 	if r.StatusCode() != http.StatusOK { | ||||||
|  | 		return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body()) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return r.Body(), nil | ||||||
|  | } | ||||||
|  | @ -0,0 +1,326 @@ | ||||||
|  | package requests | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Result suite", func() { | ||||||
|  | 	Context("with a result", func() { | ||||||
|  | 		type resultTableInput struct { | ||||||
|  | 			result             Result | ||||||
|  | 			expectedError      error | ||||||
|  | 			expectedStatusCode int | ||||||
|  | 			expectedHeaders    http.Header | ||||||
|  | 			expectedBody       []byte | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("accessors should return expected results", | ||||||
|  | 			func(in resultTableInput) { | ||||||
|  | 				if in.expectedError != nil { | ||||||
|  | 					Expect(in.result.Error()).To(MatchError(in.expectedError)) | ||||||
|  | 				} else { | ||||||
|  | 					Expect(in.result.Error()).To(BeNil()) | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode)) | ||||||
|  | 				Expect(in.result.Headers()).To(Equal(in.expectedHeaders)) | ||||||
|  | 				Expect(in.result.Body()).To(Equal(in.expectedBody)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with an empty result", resultTableInput{ | ||||||
|  | 				result:             &result{}, | ||||||
|  | 				expectedError:      nil, | ||||||
|  | 				expectedStatusCode: 0, | ||||||
|  | 				expectedHeaders:    nil, | ||||||
|  | 				expectedBody:       nil, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an error", resultTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: errors.New("error"), | ||||||
|  | 				}, | ||||||
|  | 				expectedError:      errors.New("error"), | ||||||
|  | 				expectedStatusCode: 0, | ||||||
|  | 				expectedHeaders:    nil, | ||||||
|  | 				expectedBody:       nil, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a response with no headers", resultTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusTeapot, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedError:      nil, | ||||||
|  | 				expectedStatusCode: http.StatusTeapot, | ||||||
|  | 				expectedHeaders:    nil, | ||||||
|  | 				expectedBody:       nil, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a response with no status code", resultTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						Header: http.Header{ | ||||||
|  | 							"foo": []string{"bar"}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedError:      nil, | ||||||
|  | 				expectedStatusCode: 0, | ||||||
|  | 				expectedHeaders: http.Header{ | ||||||
|  | 					"foo": []string{"bar"}, | ||||||
|  | 				}, | ||||||
|  | 				expectedBody: nil, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a response with a body", resultTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					body: []byte("some body"), | ||||||
|  | 				}, | ||||||
|  | 				expectedError:      nil, | ||||||
|  | 				expectedStatusCode: 0, | ||||||
|  | 				expectedHeaders:    nil, | ||||||
|  | 				expectedBody:       []byte("some body"), | ||||||
|  | 			}), | ||||||
|  | 			Entry("with all fields", resultTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: errors.New("some error"), | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusFound, | ||||||
|  | 						Header: http.Header{ | ||||||
|  | 							"header": []string{"value"}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("a body"), | ||||||
|  | 				}, | ||||||
|  | 				expectedError:      errors.New("some error"), | ||||||
|  | 				expectedStatusCode: http.StatusFound, | ||||||
|  | 				expectedHeaders: http.Header{ | ||||||
|  | 					"header": []string{"value"}, | ||||||
|  | 				}, | ||||||
|  | 				expectedBody: []byte("a body"), | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalInto", func() { | ||||||
|  | 		type testStruct struct { | ||||||
|  | 			A string `json:"a"` | ||||||
|  | 			B int    `json:"b"` | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		type unmarshalIntoTableInput struct { | ||||||
|  | 			result         Result | ||||||
|  | 			expectedErr    error | ||||||
|  | 			expectedOutput *testStruct | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("with a result", | ||||||
|  | 			func(in unmarshalIntoTableInput) { | ||||||
|  | 				input := &testStruct{} | ||||||
|  | 				err := in.result.UnmarshalInto(input) | ||||||
|  | 				if in.expectedErr != nil { | ||||||
|  | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
|  | 				} else { | ||||||
|  | 					Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				} | ||||||
|  | 				Expect(input).To(Equal(in.expectedOutput)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with an error", unmarshalIntoTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: errors.New("got an error"), | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("{\"a\": \"foo\"}"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("got an error"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a 409 status code", unmarshalIntoTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusConflict, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("{\"a\": \"foo\"}"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response has a valid json response", unmarshalIntoTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("{\"a\": \"foo\", \"b\": 1}"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    nil, | ||||||
|  | 				expectedOutput: &testStruct{A: "foo", B: 1}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response body is empty", unmarshalIntoTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte(""), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("error unmarshalling body: unexpected end of JSON input"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response body is not json", unmarshalIntoTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("not json"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("UnmarshalJSON", func() { | ||||||
|  | 		type testStruct struct { | ||||||
|  | 			A string `json:"a"` | ||||||
|  | 			B int    `json:"b"` | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		type unmarshalJSONTableInput struct { | ||||||
|  | 			result         Result | ||||||
|  | 			expectedErr    error | ||||||
|  | 			expectedOutput *testStruct | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("with a result", | ||||||
|  | 			func(in unmarshalJSONTableInput) { | ||||||
|  | 				j, err := in.result.UnmarshalJSON() | ||||||
|  | 				if in.expectedErr != nil { | ||||||
|  | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
|  | 					Expect(j).To(BeNil()) | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				// No error so j should not be nil
 | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				input := &testStruct{ | ||||||
|  | 					A: j.Get("a").MustString(), | ||||||
|  | 					B: j.Get("b").MustInt(), | ||||||
|  | 				} | ||||||
|  | 				Expect(input).To(Equal(in.expectedOutput)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with an error", unmarshalJSONTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: errors.New("got an error"), | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("{\"a\": \"foo\"}"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("got an error"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a 409 status code", unmarshalJSONTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusConflict, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("{\"a\": \"foo\"}"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response has a valid json response", unmarshalJSONTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("{\"a\": \"foo\", \"b\": 1}"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    nil, | ||||||
|  | 				expectedOutput: &testStruct{A: "foo", B: 1}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response body is empty", unmarshalJSONTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte(""), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("error reading json: EOF"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response body is not json", unmarshalJSONTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("not json"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:    errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"), | ||||||
|  | 				expectedOutput: &testStruct{}, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("getBodyForUnmarshal", func() { | ||||||
|  | 		type getBodyForUnmarshalTableInput struct { | ||||||
|  | 			result       *result | ||||||
|  | 			expectedErr  error | ||||||
|  | 			expectedBody []byte | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) { | ||||||
|  | 			body, err := in.result.getBodyForUnmarshal() | ||||||
|  | 			if in.expectedErr != nil { | ||||||
|  | 				Expect(err).To(MatchError(in.expectedErr)) | ||||||
|  | 			} else { | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			} | ||||||
|  | 			Expect(body).To(Equal(in.expectedBody)) | ||||||
|  | 		}, | ||||||
|  | 			Entry("when the result has an error", getBodyForUnmarshalTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: errors.New("got an error"), | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("body"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:  errors.New("got an error"), | ||||||
|  | 				expectedBody: nil, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusConflict, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("body"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:  errors.New("unexpected status \"409\": body"), | ||||||
|  | 				expectedBody: nil, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{ | ||||||
|  | 				result: &result{ | ||||||
|  | 					err: nil, | ||||||
|  | 					response: &http.Response{ | ||||||
|  | 						StatusCode: http.StatusOK, | ||||||
|  | 					}, | ||||||
|  | 					body: []byte("body"), | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:  nil, | ||||||
|  | 				expectedBody: []byte("body"), | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -83,9 +83,14 @@ func Validate(o *options.Options) error { | ||||||
| 
 | 
 | ||||||
| 			logger.Printf("Performing OIDC Discovery...") | 			logger.Printf("Performing OIDC Discovery...") | ||||||
| 
 | 
 | ||||||
| 			if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil { | 			requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" | ||||||
| 				if body, err := requests.Request(req); err == nil { | 			body, err := requests.New(requestURL). | ||||||
| 
 | 				WithContext(ctx). | ||||||
|  | 				Do(). | ||||||
|  | 				UnmarshalJSON() | ||||||
|  | 			if err != nil { | ||||||
|  | 				logger.Printf("error: failed to discover OIDC configuration: %v", err) | ||||||
|  | 			} else { | ||||||
| 				// Prefer manually configured URLs. It's a bit unclear
 | 				// Prefer manually configured URLs. It's a bit unclear
 | ||||||
| 				// why you'd be doing discovery and also providing the URLs
 | 				// why you'd be doing discovery and also providing the URLs
 | ||||||
| 				// explicitly though...
 | 				// explicitly though...
 | ||||||
|  | @ -106,11 +111,6 @@ func Validate(o *options.Options) error { | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				o.SkipOIDCDiscovery = true | 				o.SkipOIDCDiscovery = true | ||||||
| 				} else { |  | ||||||
| 					logger.Printf("error: failed to discover OIDC configuration: %v", err) |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				logger.Printf("error: failed parsing OIDC discovery URL: %v", err) |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -385,10 +385,10 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// Try as JWKS URI
 | 		// Try as JWKS URI
 | ||||||
| 		jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" | 		jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" | ||||||
| 		_, err := http.NewRequest("GET", jwksURI, nil) | 		if err := requests.New(jwksURI).Do().Error(); err != nil { | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 		verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) | 		verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) | ||||||
| 	} else { | 	} else { | ||||||
| 		verifier = provider.Verifier(config) | 		verifier = provider.Verifier(config) | ||||||
|  |  | ||||||
|  | @ -3,10 +3,8 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -91,39 +89,22 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		params.Add("resource", p.ProtectedResource.String()) | 		params.Add("resource", p.ProtectedResource.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var req *http.Request |  | ||||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |  | ||||||
| 
 |  | ||||||
| 	var resp *http.Response |  | ||||||
| 	resp, err = http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var jsonResponse struct { | 	var jsonResponse struct { | ||||||
| 		AccessToken  string `json:"access_token"` | 		AccessToken  string `json:"access_token"` | ||||||
| 		RefreshToken string `json:"refresh_token"` | 		RefreshToken string `json:"refresh_token"` | ||||||
| 		ExpiresOn    int64  `json:"expires_on,string"` | 		ExpiresOn    int64  `json:"expires_on,string"` | ||||||
| 		IDToken      string `json:"id_token"` | 		IDToken      string `json:"id_token"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 
 | ||||||
|  | 	err = requests.New(p.RedeemURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithMethod("POST"). | ||||||
|  | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
|  | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	created := time.Now() | ||||||
|  | @ -169,26 +150,22 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	req.Header = getAzureHeader(s.AccessToken) |  | ||||||
| 
 |  | ||||||
| 	json, err := requests.Request(req) |  | ||||||
| 
 | 
 | ||||||
|  | 	json, err := requests.New(p.ProfileURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getAzureHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	email, err = getEmailFromJSON(json) | 	email, err = getEmailFromJSON(json) | ||||||
| 
 |  | ||||||
| 	if err == nil && email != "" { | 	if err == nil && email != "" { | ||||||
| 		return email, err | 		return email, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	email, err = json.Get("userPrincipalName").String() | 	email, err = json.Get("userPrincipalName").String() | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("failed making request %s", err) | 		logger.Printf("failed making request %s", err) | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -2,7 +2,6 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
|  | @ -85,15 +84,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 			FullName string `json:"full_name"` | 			FullName string `json:"full_name"` | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", | 
 | ||||||
| 		p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) | 	requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken | ||||||
|  | 	err := requests.New(requestURL). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&emails) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("failed building request %s", err) | 		logger.Printf("failed making request: %v", err) | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	err = requests.RequestJSON(req, &emails) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Printf("failed making request %s", err) |  | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -101,15 +99,15 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 		teamURL := &url.URL{} | 		teamURL := &url.URL{} | ||||||
| 		*teamURL = *p.ValidateURL | 		*teamURL = *p.ValidateURL | ||||||
| 		teamURL.Path = "/2.0/teams" | 		teamURL.Path = "/2.0/teams" | ||||||
| 		req, err = http.NewRequestWithContext(ctx, "GET", | 
 | ||||||
| 			teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) | 		requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken | ||||||
|  | 
 | ||||||
|  | 		err := requests.New(requestURL). | ||||||
|  | 			WithContext(ctx). | ||||||
|  | 			Do(). | ||||||
|  | 			UnmarshalInto(&teams) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("failed building request %s", err) | 			logger.Printf("failed requesting teams membership: %v", err) | ||||||
| 			return "", err |  | ||||||
| 		} |  | ||||||
| 		err = requests.RequestJSON(req, &teams) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Printf("failed requesting teams membership %s", err) |  | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
| 		var found = false | 		var found = false | ||||||
|  | @ -129,20 +127,20 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | ||||||
| 		repositoriesURL := &url.URL{} | 		repositoriesURL := &url.URL{} | ||||||
| 		*repositoriesURL = *p.ValidateURL | 		*repositoriesURL = *p.ValidateURL | ||||||
| 		repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] | 		repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] | ||||||
| 		req, err = http.NewRequestWithContext(ctx, "GET", | 
 | ||||||
| 			repositoriesURL.String()+"?role=contributor"+ | 		requestURL := repositoriesURL.String() + "?role=contributor" + | ||||||
| 			"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") + | 			"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") + | ||||||
| 				"&access_token="+s.AccessToken, | 			"&access_token=" + s.AccessToken | ||||||
| 			nil) | 
 | ||||||
|  | 		err := requests.New(requestURL). | ||||||
|  | 			WithContext(ctx). | ||||||
|  | 			Do(). | ||||||
|  | 			UnmarshalInto(&repositories) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			logger.Printf("failed building request %s", err) | 			logger.Printf("failed checking repository access: %v", err) | ||||||
| 			return "", err |  | ||||||
| 		} |  | ||||||
| 		err = requests.RequestJSON(req, &repositories) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Printf("failed checking repository access %s", err) |  | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 		var found = false | 		var found = false | ||||||
| 		for _, repository := range repositories.Values { | 		for _, repository := range repositories.Values { | ||||||
| 			if p.Repository == repository.FullName { | 			if p.Repository == repository.FullName { | ||||||
|  |  | ||||||
|  | @ -60,13 +60,12 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	req.Header = getDigitalOceanHeader(s.AccessToken) |  | ||||||
| 
 | 
 | ||||||
| 	json, err := requests.Request(req) | 	json, err := requests.New(p.ProfileURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getDigitalOceanHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -62,20 +62,22 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	req.Header = getFacebookHeader(s.AccessToken) |  | ||||||
| 
 | 
 | ||||||
| 	type result struct { | 	type result struct { | ||||||
| 		Email string | 		Email string | ||||||
| 	} | 	} | ||||||
| 	var r result | 	var r result | ||||||
| 	err = requests.RequestJSON(req, &r) | 
 | ||||||
|  | 	requestURL := p.ProfileURL.String() + "?fields=name,email" | ||||||
|  | 	err := requests.New(requestURL). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getFacebookHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	if r.Email == "" { | 	if r.Email == "" { | ||||||
| 		return "", errors.New("no email") | 		return "", errors.New("no email") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -2,10 +2,8 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"path" | 	"path" | ||||||
|  | @ -15,6 +13,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // GitHubProvider represents an GitHub based Identity Provider
 | // GitHubProvider represents an GitHub based Identity Provider
 | ||||||
|  | @ -111,27 +110,17 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, | ||||||
| 			Path:     path.Join(p.ValidateURL.Path, "/user/orgs"), | 			Path:     path.Join(p.ValidateURL.Path, "/user/orgs"), | ||||||
| 			RawQuery: params.Encode(), | 			RawQuery: params.Encode(), | ||||||
| 		} | 		} | ||||||
| 		req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) |  | ||||||
| 		req.Header = getGitHubHeader(accessToken) |  | ||||||
| 		resp, err := http.DefaultClient.Do(req) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return false, err |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 		resp.Body.Close() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return false, err |  | ||||||
| 		} |  | ||||||
| 		if resp.StatusCode != 200 { |  | ||||||
| 			return false, fmt.Errorf( |  | ||||||
| 				"got %d from %q %s", resp.StatusCode, endpoint.String(), body) |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		var op orgsPage | 		var op orgsPage | ||||||
| 		if err := json.Unmarshal(body, &op); err != nil { | 		err := requests.New(endpoint.String()). | ||||||
|  | 			WithContext(ctx). | ||||||
|  | 			WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 			Do(). | ||||||
|  | 			UnmarshalInto(&op) | ||||||
|  | 		if err != nil { | ||||||
| 			return false, err | 			return false, err | ||||||
| 		} | 		} | ||||||
|  | 
 | ||||||
| 		if len(op) == 0 { | 		if len(op) == 0 { | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
|  | @ -187,11 +176,15 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 			RawQuery: params.Encode(), | 			RawQuery: params.Encode(), | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | 		// bodyclose cannot detect that the body is being closed later in requests.Into,
 | ||||||
| 		req.Header = getGitHubHeader(accessToken) | 		// so have to skip the linting for the next line.
 | ||||||
| 		resp, err := http.DefaultClient.Do(req) | 		// nolint:bodyclose
 | ||||||
| 		if err != nil { | 		result := requests.New(endpoint.String()). | ||||||
| 			return false, err | 			WithContext(ctx). | ||||||
|  | 			WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 			Do() | ||||||
|  | 		if result.Error() != nil { | ||||||
|  | 			return false, result.Error() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if last == 0 { | 		if last == 0 { | ||||||
|  | @ -207,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 			// link header at last page (doesn't exist last info)
 | 			// link header at last page (doesn't exist last info)
 | ||||||
| 			// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
 | 			// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
 | ||||||
| 
 | 
 | ||||||
| 			link := resp.Header.Get("Link") | 			link := result.Headers().Get("Link") | ||||||
| 			rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) | 			rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) | ||||||
| 			i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) | 			i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) | ||||||
| 
 | 
 | ||||||
|  | @ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 		if err != nil { |  | ||||||
| 			resp.Body.Close() |  | ||||||
| 			return false, err |  | ||||||
| 		} |  | ||||||
| 		resp.Body.Close() |  | ||||||
| 
 |  | ||||||
| 		if resp.StatusCode != 200 { |  | ||||||
| 			return false, fmt.Errorf( |  | ||||||
| 				"got %d from %q %s", resp.StatusCode, endpoint.String(), body) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		var tp teamsPage | 		var tp teamsPage | ||||||
| 		if err := json.Unmarshal(body, &tp); err != nil { | 		if err := result.UnmarshalInto(&tp); err != nil { | ||||||
| 			return false, fmt.Errorf("%s unmarshaling %s", err, body) | 			return false, err | ||||||
| 		} | 		} | ||||||
| 		if len(tp) == 0 { | 		if len(tp) == 0 { | ||||||
| 			break | 			break | ||||||
|  | @ -297,25 +278,13 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, | ||||||
| 		Path:   path.Join(p.ValidateURL.Path, "/repo/", p.Repo), | 		Path:   path.Join(p.ValidateURL.Path, "/repo/", p.Repo), | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) |  | ||||||
| 	req.Header = getGitHubHeader(accessToken) |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return false, fmt.Errorf( |  | ||||||
| 			"got %d from %q %s", resp.StatusCode, endpoint.String(), body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var repo repository | 	var repo repository | ||||||
| 	if err := json.Unmarshal(body, &repo); err != nil { | 	err := requests.New(endpoint.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&repo) | ||||||
|  | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -337,26 +306,15 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, | ||||||
| 		Host:   p.ValidateURL.Host, | 		Host:   p.ValidateURL.Host, | ||||||
| 		Path:   path.Join(p.ValidateURL.Path, "/user"), | 		Path:   path.Join(p.ValidateURL.Path, "/user"), | ||||||
| 	} | 	} | ||||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) |  | ||||||
| 	req.Header = getGitHubHeader(accessToken) |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 	defer resp.Body.Close() |  | ||||||
| 
 | 
 | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) | 	err := requests.New(endpoint.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getGitHubHeader(accessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&user) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return false, fmt.Errorf("got %d from %q %s", |  | ||||||
| 			resp.StatusCode, stripToken(endpoint.String()), body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err := json.Unmarshal(body, &user); err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	if p.isVerifiedUser(user.Login) { | 	if p.isVerifiedUser(user.Login) { | ||||||
| 		return true, nil | 		return true, nil | ||||||
|  | @ -372,24 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok | ||||||
| 		Host:   p.ValidateURL.Host, | 		Host:   p.ValidateURL.Host, | ||||||
| 		Path:   path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), | 		Path:   path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), | ||||||
| 	} | 	} | ||||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | 	result := requests.New(endpoint.String()). | ||||||
| 	req.Header = getGitHubHeader(accessToken) | 		WithContext(ctx). | ||||||
| 	resp, err := http.DefaultClient.Do(req) | 		WithHeaders(getGitHubHeader(accessToken)). | ||||||
| 	if err != nil { | 		Do() | ||||||
| 		return false, err | 	if result.Error() != nil { | ||||||
| 	} | 		return false, result.Error() | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if resp.StatusCode != 204 { | 	if result.StatusCode() != 204 { | ||||||
| 		return false, fmt.Errorf("got %d from %q %s", | 		return false, fmt.Errorf("got %d from %q %s", | ||||||
| 			resp.StatusCode, endpoint.String(), body) | 			result.StatusCode(), endpoint.String(), result.Body()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) | 	logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) | ||||||
| 
 | 
 | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
|  | @ -440,28 +394,14 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio | ||||||
| 		Host:   p.ValidateURL.Host, | 		Host:   p.ValidateURL.Host, | ||||||
| 		Path:   path.Join(p.ValidateURL.Path, "/user/emails"), | 		Path:   path.Join(p.ValidateURL.Path, "/user/emails"), | ||||||
| 	} | 	} | ||||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | 	err := requests.New(endpoint.String()). | ||||||
| 	req.Header = getGitHubHeader(s.AccessToken) | 		WithContext(ctx). | ||||||
| 	resp, err := http.DefaultClient.Do(req) | 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&emails) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return "", fmt.Errorf("got %d from %q %s", |  | ||||||
| 			resp.StatusCode, endpoint.String(), body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) |  | ||||||
| 
 |  | ||||||
| 	if err := json.Unmarshal(body, &emails); err != nil { |  | ||||||
| 		return "", fmt.Errorf("%s unmarshaling %s", err, body) |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	returnEmail := "" | 	returnEmail := "" | ||||||
| 	for _, email := range emails { | 	for _, email := range emails { | ||||||
|  | @ -489,34 +429,15 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | ||||||
| 		Path:   path.Join(p.ValidateURL.Path, "/user"), | 		Path:   path.Join(p.ValidateURL.Path, "/user"), | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | 	err := requests.New(endpoint.String()). | ||||||
| 	if err != nil { | 		WithContext(ctx). | ||||||
| 		return "", fmt.Errorf("could not create new GET request: %v", err) | 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||||
| 	} | 		Do(). | ||||||
| 
 | 		UnmarshalInto(&user) | ||||||
| 	req.Header = getGitHubHeader(s.AccessToken) |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	body, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	defer resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return "", fmt.Errorf("got %d from %q %s", |  | ||||||
| 			resp.StatusCode, endpoint.String(), body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) |  | ||||||
| 
 |  | ||||||
| 	if err := json.Unmarshal(body, &user); err != nil { |  | ||||||
| 		return "", fmt.Errorf("%s unmarshaling %s", err, body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Now that we have the username we can check collaborator status
 | 	// Now that we have the username we can check collaborator status
 | ||||||
| 	if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { | 	if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { | ||||||
| 		if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { | 		if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { | ||||||
|  |  | ||||||
|  | @ -2,15 +2,13 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -131,31 +129,14 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta | ||||||
| 	userInfoURL := *p.LoginURL | 	userInfoURL := *p.LoginURL | ||||||
| 	userInfoURL.Path = "/oauth/userinfo" | 	userInfoURL.Path = "/oauth/userinfo" | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to create user info request: %v", err) |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Authorization", "Bearer "+s.AccessToken) |  | ||||||
| 
 |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to perform user info request: %v", err) |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to read user info response: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var userInfo gitlabUserInfo | 	var userInfo gitlabUserInfo | ||||||
| 	err = json.Unmarshal(body, &userInfo) | 	err := requests.New(userInfoURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		SetHeader("Authorization", "Bearer "+s.AccessToken). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&userInfo) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to parse user info: %v", err) | 		return nil, fmt.Errorf("error getting user info: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &userInfo, nil | 	return &userInfo, nil | ||||||
|  |  | ||||||
|  | @ -9,13 +9,13 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
| 	"golang.org/x/oauth2/google" | 	"golang.org/x/oauth2/google" | ||||||
| 	admin "google.golang.org/api/admin/directory/v1" | 	admin "google.golang.org/api/admin/directory/v1" | ||||||
| 	"google.golang.org/api/googleapi" | 	"google.golang.org/api/googleapi" | ||||||
|  | @ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | ||||||
| 	params.Add("client_secret", clientSecret) | 	params.Add("client_secret", clientSecret) | ||||||
| 	params.Add("code", code) | 	params.Add("code", code) | ||||||
| 	params.Add("grant_type", "authorization_code") | 	params.Add("grant_type", "authorization_code") | ||||||
| 	var req *http.Request |  | ||||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |  | ||||||
| 
 |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	var jsonResponse struct { | 	var jsonResponse struct { | ||||||
| 		AccessToken  string `json:"access_token"` | 		AccessToken  string `json:"access_token"` | ||||||
|  | @ -145,10 +123,18 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | ||||||
| 		ExpiresIn    int64  `json:"expires_in"` | 		ExpiresIn    int64  `json:"expires_in"` | ||||||
| 		IDToken      string `json:"id_token"` | 		IDToken      string `json:"id_token"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 
 | ||||||
|  | 	err = requests.New(p.RedeemURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithMethod("POST"). | ||||||
|  | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
|  | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	c, err := claimsFromIDToken(jsonResponse.IDToken) | 	c, err := claimsFromIDToken(jsonResponse.IDToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
|  | @ -283,38 +269,24 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st | ||||||
| 	params.Add("client_secret", clientSecret) | 	params.Add("client_secret", clientSecret) | ||||||
| 	params.Add("refresh_token", refreshToken) | 	params.Add("refresh_token", refreshToken) | ||||||
| 	params.Add("grant_type", "refresh_token") | 	params.Add("grant_type", "refresh_token") | ||||||
| 	var req *http.Request |  | ||||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |  | ||||||
| 
 |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	var data struct { | 	var data struct { | ||||||
| 		AccessToken string `json:"access_token"` | 		AccessToken string `json:"access_token"` | ||||||
| 		ExpiresIn   int64  `json:"expires_in"` | 		ExpiresIn   int64  `json:"expires_in"` | ||||||
| 		IDToken     string `json:"id_token"` | 		IDToken     string `json:"id_token"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &data) | 
 | ||||||
|  | 	err = requests.New(p.RedeemURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithMethod("POST"). | ||||||
|  | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
|  | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return "", "", 0, err | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	token = data.AccessToken | 	token = data.AccessToken | ||||||
| 	idToken = data.IDToken | 	idToken = data.IDToken | ||||||
| 	expires = time.Duration(data.ExpiresIn) * time.Second | 	expires = time.Duration(data.ExpiresIn) * time.Second | ||||||
|  |  | ||||||
|  | @ -2,7 +2,6 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
|  | @ -56,20 +55,22 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h | ||||||
| 		params := url.Values{"access_token": {accessToken}} | 		params := url.Values{"access_token": {accessToken}} | ||||||
| 		endpoint = endpoint + "?" + params.Encode() | 		endpoint = endpoint + "?" + params.Encode() | ||||||
| 	} | 	} | ||||||
| 	resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header) | 
 | ||||||
| 	if err != nil { | 	result := requests.New(endpoint). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(header). | ||||||
|  | 		Do() | ||||||
|  | 	if result.Error() != nil { | ||||||
| 		logger.Printf("GET %s", stripToken(endpoint)) | 		logger.Printf("GET %s", stripToken(endpoint)) | ||||||
| 		logger.Printf("token validation request failed: %s", err) | 		logger.Printf("token validation request failed: %s", result.Error()) | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	body, _ := ioutil.ReadAll(resp.Body) | 	logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) | ||||||
| 	resp.Body.Close() |  | ||||||
| 	logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) |  | ||||||
| 
 | 
 | ||||||
| 	if resp.StatusCode == 200 { | 	if result.StatusCode() == 200 { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) | 	logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -2,7 +2,6 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | @ -51,14 +50,11 @@ func (p *KeycloakProvider) SetGroup(group string) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||||
| 
 | 	json, err := requests.New(p.ValidateURL.String()). | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) | 		WithContext(ctx). | ||||||
| 	req.Header.Set("Authorization", "Bearer "+s.AccessToken) | 		SetHeader("Authorization", "Bearer "+s.AccessToken). | ||||||
| 	if err != nil { | 		Do(). | ||||||
| 		logger.Printf("failed building request %s", err) | 		UnmarshalJSON() | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	json, err := requests.Request(req) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("failed making request %s", err) | 		logger.Printf("failed making request %s", err) | ||||||
| 		return "", err | 		return "", err | ||||||
|  |  | ||||||
|  | @ -58,13 +58,13 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	req.Header = getLinkedInHeader(s.AccessToken) |  | ||||||
| 
 | 
 | ||||||
| 	json, err := requests.Request(req) | 	requestURL := p.ProfileURL.String() + "?format=json" | ||||||
|  | 	json, err := requests.New(requestURL). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getLinkedInHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/dgrijalva/jwt-go" | 	"github.com/dgrijalva/jwt-go" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
| 	"gopkg.in/square/go-jose.v2" | 	"gopkg.in/square/go-jose.v2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -128,51 +129,34 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) { | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { | func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) { | ||||||
| 	// query the user info endpoint for user attributes
 |  | ||||||
| 	var req *http.Request |  | ||||||
| 	req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Authorization", "Bearer "+accessToken) |  | ||||||
| 
 |  | ||||||
| 	resp, err := http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// parse the user attributes from the data we got and make sure that
 | 	// parse the user attributes from the data we got and make sure that
 | ||||||
| 	// the email address has been validated.
 | 	// the email address has been validated.
 | ||||||
| 	var emailData struct { | 	var emailData struct { | ||||||
| 		Email         string `json:"email"` | 		Email         string `json:"email"` | ||||||
| 		EmailVerified bool   `json:"email_verified"` | 		EmailVerified bool   `json:"email_verified"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &emailData) | 
 | ||||||
|  | 	// query the user info endpoint for user attributes
 | ||||||
|  | 	err := requests.New(userInfoEndpoint). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		SetHeader("Authorization", "Bearer "+accessToken). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&emailData) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return "", err | ||||||
| 	} | 	} | ||||||
| 	if emailData.Email == "" { | 
 | ||||||
| 		err = fmt.Errorf("missing email") | 	email := emailData.Email | ||||||
| 		return | 	if email == "" { | ||||||
|  | 		return "", fmt.Errorf("missing email") | ||||||
| 	} | 	} | ||||||
| 	email = emailData.Email | 
 | ||||||
| 	if !emailData.EmailVerified { | 	if !emailData.EmailVerified { | ||||||
| 		err = fmt.Errorf("email %s not listed as verified", email) | 		return "", fmt.Errorf("email %s not listed as verified", email) | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	return | 
 | ||||||
|  | 	return email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
|  | @ -201,30 +185,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | ||||||
| 	params.Add("code", code) | 	params.Add("code", code) | ||||||
| 	params.Add("grant_type", "authorization_code") | 	params.Add("grant_type", "authorization_code") | ||||||
| 
 | 
 | ||||||
| 	var req *http.Request |  | ||||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |  | ||||||
| 
 |  | ||||||
| 	var resp *http.Response |  | ||||||
| 	resp, err = http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Get the token from the body that we got from the token endpoint.
 | 	// Get the token from the body that we got from the token endpoint.
 | ||||||
| 	var jsonResponse struct { | 	var jsonResponse struct { | ||||||
| 		AccessToken string `json:"access_token"` | 		AccessToken string `json:"access_token"` | ||||||
|  | @ -232,9 +192,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | ||||||
| 		TokenType   string `json:"token_type"` | 		TokenType   string `json:"token_type"` | ||||||
| 		ExpiresIn   int64  `json:"expires_in"` | 		ExpiresIn   int64  `json:"expires_in"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 	err = requests.New(p.RedeemURL.String()). | ||||||
|  | 		WithContext(ctx). | ||||||
|  | 		WithMethod("POST"). | ||||||
|  | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
|  | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalInto(&jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// check nonce here
 | 	// check nonce here
 | ||||||
|  |  | ||||||
|  | @ -6,7 +6,6 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -31,18 +30,15 @@ func getNextcloudHeader(accessToken string) http.Header { | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", | 	json, err := requests.New(p.ValidateURL.String()). | ||||||
| 		p.ValidateURL.String(), nil) | 		WithContext(ctx). | ||||||
|  | 		WithHeaders(getNextcloudHeader(s.AccessToken)). | ||||||
|  | 		Do(). | ||||||
|  | 		UnmarshalJSON() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Printf("failed building request %s", err) | 		return "", fmt.Errorf("error making request: %v", err) | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
| 	req.Header = getNextcloudHeader(s.AccessToken) |  | ||||||
| 	json, err := requests.Request(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Printf("failed making request %s", err) |  | ||||||
| 		return "", err |  | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	email, err := json.Get("ocs").Get("data").Get("email").String() | 	email, err := json.Get("ocs").Get("data").Get("email").String() | ||||||
| 	return email, err | 	return email, err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -256,13 +256,11 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | ||||||
| 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 | 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 | ||||||
| 		// contents at the profileURL contains the email.
 | 		// contents at the profileURL contains the email.
 | ||||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||||
| 		req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) | 		respJSON, err := requests.New(profileURL). | ||||||
| 		if err != nil { | 			WithContext(ctx). | ||||||
| 			return nil, err | 			WithHeaders(getOIDCHeader(accessToken)). | ||||||
| 		} | 			Do(). | ||||||
| 		req.Header = getOIDCHeader(accessToken) | 			UnmarshalJSON() | ||||||
| 
 |  | ||||||
| 		respJSON, err := requests.Request(req) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | @ -3,17 +3,15 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"net/http" |  | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc" | 	"github.com/coreos/go-oidc" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var _ Provider = (*ProviderData)(nil) | var _ Provider = (*ProviderData)(nil) | ||||||
|  | @ -39,35 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		params.Add("resource", p.ProtectedResource.String()) | 		params.Add("resource", p.ProtectedResource.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var req *http.Request | 	result := requests.New(p.RedeemURL.String()). | ||||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) | 		WithContext(ctx). | ||||||
| 	if err != nil { | 		WithMethod("POST"). | ||||||
| 		return | 		WithBody(bytes.NewBufferString(params.Encode())). | ||||||
| 	} | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | 		Do() | ||||||
| 
 | 	if result.Error() != nil { | ||||||
| 	var resp *http.Response | 		return nil, result.Error() | ||||||
| 	resp, err = http.DefaultClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	var body []byte |  | ||||||
| 	body, err = ioutil.ReadAll(resp.Body) |  | ||||||
| 	resp.Body.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if resp.StatusCode != 200 { |  | ||||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// blindly try json and x-www-form-urlencoded
 | 	// blindly try json and x-www-form-urlencoded
 | ||||||
| 	var jsonResponse struct { | 	var jsonResponse struct { | ||||||
| 		AccessToken string `json:"access_token"` | 		AccessToken string `json:"access_token"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 	err = result.UnmarshalInto(&jsonResponse) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		s = &sessions.SessionState{ | 		s = &sessions.SessionState{ | ||||||
| 			AccessToken: jsonResponse.AccessToken, | 			AccessToken: jsonResponse.AccessToken, | ||||||
|  | @ -76,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var v url.Values | 	var v url.Values | ||||||
| 	v, err = url.ParseQuery(string(body)) | 	v, err = url.ParseQuery(string(result.Body())) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | @ -84,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | ||||||
| 		created := time.Now() | 		created := time.Now() | ||||||
| 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | ||||||
| 	} else { | 	} else { | ||||||
| 		err = fmt.Errorf("no access token found %s", body) | 		err = fmt.Errorf("no access token found %s", result.Body()) | ||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue