Merge pull request #660 from oauth2-proxy/request-builder
Use builder pattern to simplify requests to external endpoints
This commit is contained in:
		
						commit
						d29766609b
					
				|  | @ -8,6 +8,7 @@ | |||
| 
 | ||||
| ## Changes since v6.0.0 | ||||
| 
 | ||||
| - [#660](https://github.com/oauth2-proxy/oauth2-proxy/pull/660) Use builder pattern to simplify requests to external endpoints (@JoelSpeed) | ||||
| - [#591](https://github.com/oauth2-proxy/oauth2-proxy/pull/591) Introduce upstream package with new reverse proxy implementation (@JoelSpeed) | ||||
| - [#576](https://github.com/oauth2-proxy/oauth2-proxy/pull/576) Separate Cookie validation out of main options validation (@JoelSpeed) | ||||
| - [#656](https://github.com/oauth2-proxy/oauth2-proxy/pull/656) Split long session cookies more precisely (@NickMeves) | ||||
|  |  | |||
|  | @ -0,0 +1,118 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| ) | ||||
| 
 | ||||
| // Builder allows users to construct a request and then execute the
 | ||||
| // request via Do().
 | ||||
| // Do returns a Result which allows the user to get the body,
 | ||||
| // unmarshal the body into an interface, or into a simplejson.Json.
 | ||||
| type Builder interface { | ||||
| 	WithContext(context.Context) Builder | ||||
| 	WithBody(io.Reader) Builder | ||||
| 	WithMethod(string) Builder | ||||
| 	WithHeaders(http.Header) Builder | ||||
| 	SetHeader(key, value string) Builder | ||||
| 	Do() Result | ||||
| } | ||||
| 
 | ||||
| type builder struct { | ||||
| 	context  context.Context | ||||
| 	method   string | ||||
| 	endpoint string | ||||
| 	body     io.Reader | ||||
| 	header   http.Header | ||||
| 	result   *result | ||||
| } | ||||
| 
 | ||||
| // New provides a new Builder for the given endpoint.
 | ||||
| func New(endpoint string) Builder { | ||||
| 	return &builder{ | ||||
| 		endpoint: endpoint, | ||||
| 		method:   "GET", | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // WithContext adds a context to the request.
 | ||||
| // If no context is provided, context.Background() is used instead.
 | ||||
| func (r *builder) WithContext(ctx context.Context) Builder { | ||||
| 	r.context = ctx | ||||
| 	return r | ||||
| } | ||||
| 
 | ||||
| // WithBody adds a body to the request.
 | ||||
| func (r *builder) WithBody(body io.Reader) Builder { | ||||
| 	r.body = body | ||||
| 	return r | ||||
| } | ||||
| 
 | ||||
| // WithMethod sets the request method. Defaults to "GET".
 | ||||
| func (r *builder) WithMethod(method string) Builder { | ||||
| 	r.method = method | ||||
| 	return r | ||||
| } | ||||
| 
 | ||||
| // WithHeaders replaces the request header map with the given header map.
 | ||||
| func (r *builder) WithHeaders(header http.Header) Builder { | ||||
| 	r.header = header | ||||
| 	return r | ||||
| } | ||||
| 
 | ||||
| // SetHeader sets a single header to the given value.
 | ||||
| // May be used to add multiple headers.
 | ||||
| func (r *builder) SetHeader(key, value string) Builder { | ||||
| 	if r.header == nil { | ||||
| 		r.header = make(http.Header) | ||||
| 	} | ||||
| 	r.header.Set(key, value) | ||||
| 	return r | ||||
| } | ||||
| 
 | ||||
| // Do performs the request and returns the response in its raw form.
 | ||||
| // If the request has already been performed, returns the previous result.
 | ||||
| // This will not allow you to repeat a request.
 | ||||
| func (r *builder) Do() Result { | ||||
| 	if r.result != nil { | ||||
| 		// Request has already been done
 | ||||
| 		return r.result | ||||
| 	} | ||||
| 
 | ||||
| 	// Must provide a non-nil context to NewRequestWithContext
 | ||||
| 	if r.context == nil { | ||||
| 		r.context = context.Background() | ||||
| 	} | ||||
| 
 | ||||
| 	return r.do() | ||||
| } | ||||
| 
 | ||||
| // do creates the request, executes it with the default client and extracts the
 | ||||
| // the body into the response
 | ||||
| func (r *builder) do() Result { | ||||
| 	req, err := http.NewRequestWithContext(r.context, r.method, r.endpoint, r.body) | ||||
| 	if err != nil { | ||||
| 		r.result = &result{err: fmt.Errorf("error creating request: %v", err)} | ||||
| 		return r.result | ||||
| 	} | ||||
| 	req.Header = r.header | ||||
| 
 | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		r.result = &result{err: fmt.Errorf("error performing request: %v", err)} | ||||
| 		return r.result | ||||
| 	} | ||||
| 
 | ||||
| 	defer resp.Body.Close() | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		r.result = &result{err: fmt.Errorf("error reading response body: %v", err)} | ||||
| 		return r.result | ||||
| 	} | ||||
| 
 | ||||
| 	r.result = &result{response: resp, body: body} | ||||
| 	return r.result | ||||
| } | ||||
|  | @ -0,0 +1,376 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/bitly/go-simplejson" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Builder suite", func() { | ||||
| 	var b Builder | ||||
| 	getBuilder := func() Builder { return b } | ||||
| 
 | ||||
| 	baseHeaders := http.Header{ | ||||
| 		"Accept-Encoding": []string{"gzip"}, | ||||
| 		"User-Agent":      []string{"Go-http-client/1.1"}, | ||||
| 	} | ||||
| 
 | ||||
| 	BeforeEach(func() { | ||||
| 		// Most tests will request the server address
 | ||||
| 		b = New(serverAddr + "/json/path") | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with a basic request", func() { | ||||
| 		assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 			Method:     "GET", | ||||
| 			Header:     baseHeaders, | ||||
| 			Body:       []byte{}, | ||||
| 			RequestURI: "/json/path", | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with a context", func() { | ||||
| 		var ctx context.Context | ||||
| 		var cancel context.CancelFunc | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			ctx, cancel = context.WithCancel(context.Background()) | ||||
| 			b = b.WithContext(ctx) | ||||
| 		}) | ||||
| 
 | ||||
| 		AfterEach(func() { | ||||
| 			cancel() | ||||
| 		}) | ||||
| 
 | ||||
| 		assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 			Method:     "GET", | ||||
| 			Header:     baseHeaders, | ||||
| 			Body:       []byte{}, | ||||
| 			RequestURI: "/json/path", | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("if the context is cancelled", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				cancel() | ||||
| 			}) | ||||
| 
 | ||||
| 			assertRequestError(getBuilder, "context canceled") | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with a body", func() { | ||||
| 		const body = "{\"some\": \"body\"}" | ||||
| 		header := baseHeaders.Clone() | ||||
| 		header.Set("Content-Length", fmt.Sprintf("%d", len(body))) | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			buf := bytes.NewBuffer([]byte(body)) | ||||
| 			b = b.WithBody(buf) | ||||
| 		}) | ||||
| 
 | ||||
| 		assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 			Method:     "GET", | ||||
| 			Header:     header, | ||||
| 			Body:       []byte(body), | ||||
| 			RequestURI: "/json/path", | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with a method", func() { | ||||
| 		Context("POST with a body", func() { | ||||
| 			const body = "{\"some\": \"body\"}" | ||||
| 			header := baseHeaders.Clone() | ||||
| 			header.Set("Content-Length", fmt.Sprintf("%d", len(body))) | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				buf := bytes.NewBuffer([]byte(body)) | ||||
| 				b = b.WithMethod("POST").WithBody(buf) | ||||
| 			}) | ||||
| 
 | ||||
| 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 				Method:     "POST", | ||||
| 				Header:     header, | ||||
| 				Body:       []byte(body), | ||||
| 				RequestURI: "/json/path", | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("POST without a body", func() { | ||||
| 			header := baseHeaders.Clone() | ||||
| 			header.Set("Content-Length", "0") | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				b = b.WithMethod("POST") | ||||
| 			}) | ||||
| 
 | ||||
| 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 				Method:     "POST", | ||||
| 				Header:     header, | ||||
| 				Body:       []byte{}, | ||||
| 				RequestURI: "/json/path", | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("OPTIONS", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				b = b.WithMethod("OPTIONS") | ||||
| 			}) | ||||
| 
 | ||||
| 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 				Method:     "OPTIONS", | ||||
| 				Header:     baseHeaders, | ||||
| 				Body:       []byte{}, | ||||
| 				RequestURI: "/json/path", | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("INVALID-\\t-METHOD", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				b = b.WithMethod("INVALID-\t-METHOD") | ||||
| 			}) | ||||
| 
 | ||||
| 			assertRequestError(getBuilder, "error creating request: net/http: invalid method \"INVALID-\\t-METHOD\"") | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with headers", func() { | ||||
| 		Context("setting a header", func() { | ||||
| 			header := baseHeaders.Clone() | ||||
| 			header.Set("header", "value") | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				b = b.SetHeader("header", "value") | ||||
| 			}) | ||||
| 
 | ||||
| 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 				Method:     "GET", | ||||
| 				Header:     header, | ||||
| 				Body:       []byte{}, | ||||
| 				RequestURI: "/json/path", | ||||
| 			}) | ||||
| 
 | ||||
| 			Context("then replacing the headers", func() { | ||||
| 				replacementHeaders := http.Header{ | ||||
| 					"Accept-Encoding": []string{"*"}, | ||||
| 					"User-Agent":      []string{"test-agent"}, | ||||
| 					"Foo":             []string{"bar, baz"}, | ||||
| 				} | ||||
| 
 | ||||
| 				BeforeEach(func() { | ||||
| 					b = b.WithHeaders(replacementHeaders) | ||||
| 				}) | ||||
| 
 | ||||
| 				assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 					Method:     "GET", | ||||
| 					Header:     replacementHeaders, | ||||
| 					Body:       []byte{}, | ||||
| 					RequestURI: "/json/path", | ||||
| 				}) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("replacing the header", func() { | ||||
| 			replacementHeaders := http.Header{ | ||||
| 				"Accept-Encoding": []string{"*"}, | ||||
| 				"User-Agent":      []string{"test-agent"}, | ||||
| 				"Foo":             []string{"bar, baz"}, | ||||
| 			} | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				b = b.WithHeaders(replacementHeaders) | ||||
| 			}) | ||||
| 
 | ||||
| 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 				Method:     "GET", | ||||
| 				Header:     replacementHeaders, | ||||
| 				Body:       []byte{}, | ||||
| 				RequestURI: "/json/path", | ||||
| 			}) | ||||
| 
 | ||||
| 			Context("then setting a header", func() { | ||||
| 				header := replacementHeaders.Clone() | ||||
| 				header.Set("User-Agent", "different-agent") | ||||
| 
 | ||||
| 				BeforeEach(func() { | ||||
| 					b = b.SetHeader("User-Agent", "different-agent") | ||||
| 				}) | ||||
| 
 | ||||
| 				assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 					Method:     "GET", | ||||
| 					Header:     header, | ||||
| 					Body:       []byte{}, | ||||
| 					RequestURI: "/json/path", | ||||
| 				}) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("if the request has been completed and then modified", func() { | ||||
| 		BeforeEach(func() { | ||||
| 			result := b.Do() | ||||
| 			Expect(result.Error()).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			b.WithMethod("POST") | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("should not redo the request", func() { | ||||
| 			assertSuccessfulRequest(getBuilder, testHTTPRequest{ | ||||
| 				Method:     "GET", | ||||
| 				Header:     baseHeaders, | ||||
| 				Body:       []byte{}, | ||||
| 				RequestURI: "/json/path", | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("when the requested page is not found", func() { | ||||
| 		BeforeEach(func() { | ||||
| 			b = New(serverAddr + "/not-found") | ||||
| 		}) | ||||
| 
 | ||||
| 		assertJSONError(getBuilder, "404 page not found") | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("when the requested page is not valid JSON", func() { | ||||
| 		BeforeEach(func() { | ||||
| 			b = New(serverAddr + "/string/path") | ||||
| 		}) | ||||
| 
 | ||||
| 		assertJSONError(getBuilder, "invalid character 'O' looking for beginning of value") | ||||
| 	}) | ||||
| }) | ||||
| 
 | ||||
| func assertSuccessfulRequest(builder func() Builder, expectedRequest testHTTPRequest) { | ||||
| 	Context("Do", func() { | ||||
| 		var result Result | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			result = builder().Do() | ||||
| 			Expect(result.Error()).ToNot(HaveOccurred()) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("returns a successful status", func() { | ||||
| 			Expect(result.StatusCode()).To(Equal(http.StatusOK)) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("made the expected request", func() { | ||||
| 			actualRequest := testHTTPRequest{} | ||||
| 			Expect(json.Unmarshal(result.Body(), &actualRequest)).To(Succeed()) | ||||
| 
 | ||||
| 			Expect(actualRequest).To(Equal(expectedRequest)) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalInto", func() { | ||||
| 		var actualRequest testHTTPRequest | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			Expect(builder().Do().UnmarshalInto(&actualRequest)).To(Succeed()) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("made the expected request", func() { | ||||
| 			Expect(actualRequest).To(Equal(expectedRequest)) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalJSON", func() { | ||||
| 		var response *simplejson.Json | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			var err error | ||||
| 			response, err = builder().Do().UnmarshalJSON() | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("made the expected reqest", func() { | ||||
| 			header := http.Header{} | ||||
| 			for key, value := range response.Get("Header").MustMap() { | ||||
| 				vs, ok := value.([]interface{}) | ||||
| 				Expect(ok).To(BeTrue()) | ||||
| 				svs := []string{} | ||||
| 				for _, v := range vs { | ||||
| 					sv, ok := v.(string) | ||||
| 					Expect(ok).To(BeTrue()) | ||||
| 					svs = append(svs, sv) | ||||
| 				} | ||||
| 				header[key] = svs | ||||
| 			} | ||||
| 
 | ||||
| 			// Other json unmarhsallers base64 decode byte slices automatically
 | ||||
| 			body, err := base64.StdEncoding.DecodeString(response.Get("Body").MustString()) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			actualRequest := testHTTPRequest{ | ||||
| 				Method:     response.Get("Method").MustString(), | ||||
| 				Header:     header, | ||||
| 				Body:       body, | ||||
| 				RequestURI: response.Get("RequestURI").MustString(), | ||||
| 			} | ||||
| 
 | ||||
| 			Expect(actualRequest).To(Equal(expectedRequest)) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func assertRequestError(builder func() Builder, errorMessage string) { | ||||
| 	Context("Do", func() { | ||||
| 		It("returns an error", func() { | ||||
| 			result := builder().Do() | ||||
| 			Expect(result.Error()).To(MatchError(ContainSubstring(errorMessage))) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalInto", func() { | ||||
| 		It("returns an error", func() { | ||||
| 			var actualRequest testHTTPRequest | ||||
| 			err := builder().Do().UnmarshalInto(&actualRequest) | ||||
| 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||
| 
 | ||||
| 			// Should be empty
 | ||||
| 			Expect(actualRequest).To(Equal(testHTTPRequest{})) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalJSON", func() { | ||||
| 		It("returns an error", func() { | ||||
| 			resp, err := builder().Do().UnmarshalJSON() | ||||
| 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||
| 			Expect(resp).To(BeNil()) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func assertJSONError(builder func() Builder, errorMessage string) { | ||||
| 	Context("Do", func() { | ||||
| 		It("does not return an error", func() { | ||||
| 			result := builder().Do() | ||||
| 			Expect(result.Error()).To(BeNil()) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalInto", func() { | ||||
| 		It("returns an error", func() { | ||||
| 			var actualRequest testHTTPRequest | ||||
| 			err := builder().Do().UnmarshalInto(&actualRequest) | ||||
| 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||
| 
 | ||||
| 			// Should be empty
 | ||||
| 			Expect(actualRequest).To(Equal(testHTTPRequest{})) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalJSON", func() { | ||||
| 		It("returns an error", func() { | ||||
| 			resp, err := builder().Do().UnmarshalJSON() | ||||
| 			Expect(err).To(MatchError(ContainSubstring(errorMessage))) | ||||
| 			Expect(resp).To(BeNil()) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  | @ -1,74 +0,0 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/bitly/go-simplejson" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| // Request parses the request body into a simplejson.Json object
 | ||||
| func Request(req *http.Request) (*simplejson.Json, error) { | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("%s %s %s", req.Method, req.URL, err) | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	if body != nil { | ||||
| 		defer resp.Body.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("problem reading http request body: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return nil, fmt.Errorf("got %d %s", resp.StatusCode, body) | ||||
| 	} | ||||
| 
 | ||||
| 	data, err := simplejson.NewJson(body) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error unmarshalling json: %w", err) | ||||
| 	} | ||||
| 	return data, nil | ||||
| } | ||||
| 
 | ||||
| // RequestJSON parses the request body into the given interface
 | ||||
| func RequestJSON(req *http.Request, v interface{}) error { | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("%s %s %s", req.Method, req.URL, err) | ||||
| 		return err | ||||
| 	} | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	if body != nil { | ||||
| 		defer resp.Body.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error reading body from http response: %w", err) | ||||
| 	} | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return fmt.Errorf("got %d %s", resp.StatusCode, body) | ||||
| 	} | ||||
| 	return json.Unmarshal(body, v) | ||||
| } | ||||
| 
 | ||||
| // RequestUnparsedResponse performs a GET and returns the raw response object
 | ||||
| func RequestUnparsedResponse(ctx context.Context, url string, header http.Header) (resp *http.Response, err error) { | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", url, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error performing get request: %w", err) | ||||
| 	} | ||||
| 	req.Header = header | ||||
| 
 | ||||
| 	return http.DefaultClient.Do(req) | ||||
| } | ||||
|  | @ -0,0 +1,96 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	server     *httptest.Server | ||||
| 	serverAddr string | ||||
| ) | ||||
| 
 | ||||
| func TestRequetsSuite(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	log.SetOutput(GinkgoWriter) | ||||
| 
 | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "Requests Suite") | ||||
| } | ||||
| 
 | ||||
| var _ = BeforeSuite(func() { | ||||
| 	// Set up a webserver that reflects requests
 | ||||
| 	mux := http.NewServeMux() | ||||
| 	mux.Handle("/json/", &testHTTPUpstream{}) | ||||
| 	mux.HandleFunc("/string/", func(rw http.ResponseWriter, _ *http.Request) { | ||||
| 		rw.Write([]byte("OK")) | ||||
| 	}) | ||||
| 	server = httptest.NewServer(mux) | ||||
| 	serverAddr = fmt.Sprintf("http://%s", server.Listener.Addr().String()) | ||||
| }) | ||||
| 
 | ||||
| var _ = AfterSuite(func() { | ||||
| 	server.Close() | ||||
| }) | ||||
| 
 | ||||
| // testHTTPRequest is a struct used to capture the state of a request made to
 | ||||
| // the test server
 | ||||
| type testHTTPRequest struct { | ||||
| 	Method     string | ||||
| 	Header     http.Header | ||||
| 	Body       []byte | ||||
| 	RequestURI string | ||||
| } | ||||
| 
 | ||||
| type testHTTPUpstream struct{} | ||||
| 
 | ||||
| func (t *testHTTPUpstream) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||
| 	request, err := toTestHTTPRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.writeError(rw, err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	data, err := json.Marshal(request) | ||||
| 	if err != nil { | ||||
| 		t.writeError(rw, err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	rw.Header().Set("Content-Type", "application/json") | ||||
| 	rw.Write(data) | ||||
| } | ||||
| 
 | ||||
| func (t *testHTTPUpstream) writeError(rw http.ResponseWriter, err error) { | ||||
| 	rw.WriteHeader(500) | ||||
| 	if err != nil { | ||||
| 		rw.Write([]byte(err.Error())) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func toTestHTTPRequest(req *http.Request) (testHTTPRequest, error) { | ||||
| 	requestBody := []byte{} | ||||
| 	if req.Body != http.NoBody { | ||||
| 		var err error | ||||
| 		requestBody, err = ioutil.ReadAll(req.Body) | ||||
| 		if err != nil { | ||||
| 			return testHTTPRequest{}, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return testHTTPRequest{ | ||||
| 		Method:     req.Method, | ||||
| 		Header:     req.Header, | ||||
| 		Body:       requestBody, | ||||
| 		RequestURI: req.RequestURI, | ||||
| 	}, nil | ||||
| } | ||||
|  | @ -1,136 +0,0 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/bitly/go-simplejson" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func testBackend(t *testing.T, responseCode int, payload string) *httptest.Server { | ||||
| 	return httptest.NewServer(http.HandlerFunc( | ||||
| 		func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.WriteHeader(responseCode) | ||||
| 			_, err := w.Write([]byte(payload)) | ||||
| 			require.NoError(t, err) | ||||
| 		})) | ||||
| } | ||||
| 
 | ||||
| func TestRequest(t *testing.T) { | ||||
| 	backend := testBackend(t, 200, "{\"foo\": \"bar\"}") | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	req, _ := http.NewRequest("GET", backend.URL, nil) | ||||
| 	response, err := Request(req) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	result, err := response.Get("foo").String() | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "bar", result) | ||||
| } | ||||
| 
 | ||||
| func TestRequestFailure(t *testing.T) { | ||||
| 	// Create a backend to generate a test URL, then close it to cause a
 | ||||
| 	// connection error.
 | ||||
| 	backend := testBackend(t, 200, "{\"foo\": \"bar\"}") | ||||
| 	backend.Close() | ||||
| 
 | ||||
| 	req, err := http.NewRequest("GET", backend.URL, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	resp, err := Request(req) | ||||
| 	assert.Equal(t, (*simplejson.Json)(nil), resp) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	if !strings.Contains(err.Error(), "refused") { | ||||
| 		t.Error("expected error when a connection fails: ", err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestHttpErrorCode(t *testing.T) { | ||||
| 	backend := testBackend(t, 404, "{\"foo\": \"bar\"}") | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	req, err := http.NewRequest("GET", backend.URL, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	resp, err := Request(req) | ||||
| 	assert.Equal(t, (*simplejson.Json)(nil), resp) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| } | ||||
| 
 | ||||
| func TestJsonParsingError(t *testing.T) { | ||||
| 	backend := testBackend(t, 200, "not well-formed JSON") | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	req, err := http.NewRequest("GET", backend.URL, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	resp, err := Request(req) | ||||
| 	assert.Equal(t, (*simplejson.Json)(nil), resp) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| } | ||||
| 
 | ||||
| // Parsing a URL practically never fails, so we won't cover that test case.
 | ||||
| func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { | ||||
| 	backend := httptest.NewServer(http.HandlerFunc( | ||||
| 		func(w http.ResponseWriter, r *http.Request) { | ||||
| 			token := r.FormValue("access_token") | ||||
| 			if r.URL.Path == "/" && token == "my_token" { | ||||
| 				w.WriteHeader(200) | ||||
| 				_, err := w.Write([]byte("some payload")) | ||||
| 				require.NoError(t, err) | ||||
| 			} else { | ||||
| 				w.WriteHeader(403) | ||||
| 			} | ||||
| 		})) | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	response, err := RequestUnparsedResponse( | ||||
| 		context.Background(), backend.URL+"?access_token=my_token", nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	defer response.Body.Close() | ||||
| 
 | ||||
| 	assert.Equal(t, 200, response.StatusCode) | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "some payload", string(body)) | ||||
| } | ||||
| 
 | ||||
| func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { | ||||
| 	backend := testBackend(t, 200, "some payload") | ||||
| 	// Close the backend now to force a request failure.
 | ||||
| 	backend.Close() | ||||
| 
 | ||||
| 	response, err := RequestUnparsedResponse( | ||||
| 		context.Background(), backend.URL+"?access_token=my_token", nil) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, (*http.Response)(nil), response) | ||||
| } | ||||
| 
 | ||||
| func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { | ||||
| 	backend := httptest.NewServer(http.HandlerFunc( | ||||
| 		func(w http.ResponseWriter, r *http.Request) { | ||||
| 			if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { | ||||
| 				w.WriteHeader(200) | ||||
| 				_, err := w.Write([]byte("some payload")) | ||||
| 				require.NoError(t, err) | ||||
| 			} else { | ||||
| 				w.WriteHeader(403) | ||||
| 			} | ||||
| 		})) | ||||
| 	defer backend.Close() | ||||
| 
 | ||||
| 	headers := make(http.Header) | ||||
| 	headers.Set("Auth", "my_token") | ||||
| 	response, err := RequestUnparsedResponse(context.Background(), backend.URL, headers) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	defer response.Body.Close() | ||||
| 
 | ||||
| 	assert.Equal(t, 200, response.StatusCode) | ||||
| 	body, err := ioutil.ReadAll(response.Body) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	assert.Equal(t, "some payload", string(body)) | ||||
| } | ||||
|  | @ -0,0 +1,98 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/bitly/go-simplejson" | ||||
| ) | ||||
| 
 | ||||
| // Result is the result of a request created by a Builder
 | ||||
| type Result interface { | ||||
| 	Error() error | ||||
| 	StatusCode() int | ||||
| 	Headers() http.Header | ||||
| 	Body() []byte | ||||
| 	UnmarshalInto(interface{}) error | ||||
| 	UnmarshalJSON() (*simplejson.Json, error) | ||||
| } | ||||
| 
 | ||||
| type result struct { | ||||
| 	err      error | ||||
| 	response *http.Response | ||||
| 	body     []byte | ||||
| } | ||||
| 
 | ||||
| // Error returns an error from the result if present
 | ||||
| func (r *result) Error() error { | ||||
| 	return r.err | ||||
| } | ||||
| 
 | ||||
| // StatusCode returns the response's status code
 | ||||
| func (r *result) StatusCode() int { | ||||
| 	if r.response != nil { | ||||
| 		return r.response.StatusCode | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
| 
 | ||||
| // Headers returns the response's headers
 | ||||
| func (r *result) Headers() http.Header { | ||||
| 	if r.response != nil { | ||||
| 		return r.response.Header | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Body returns the response's body
 | ||||
| func (r *result) Body() []byte { | ||||
| 	return r.body | ||||
| } | ||||
| 
 | ||||
| // UnmarshalInto attempts to unmarshal the response into the the given interface.
 | ||||
| // The response body is assumed to be JSON.
 | ||||
| // The response must have a 200 status otherwise an error will be returned.
 | ||||
| func (r *result) UnmarshalInto(into interface{}) error { | ||||
| 	body, err := r.getBodyForUnmarshal() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := json.Unmarshal(body, into); err != nil { | ||||
| 		return fmt.Errorf("error unmarshalling body: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // UnmarshalJSON performs the request and attempts to unmarshal the response into a
 | ||||
| // simplejson.Json. The response body is assume to be JSON.
 | ||||
| // The response must have a 200 status otherwise an error will be returned.
 | ||||
| func (r *result) UnmarshalJSON() (*simplejson.Json, error) { | ||||
| 	body, err := r.getBodyForUnmarshal() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	data, err := simplejson.NewJson(body) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error reading json: %v", err) | ||||
| 	} | ||||
| 	return data, nil | ||||
| } | ||||
| 
 | ||||
| // getBodyForUnmarshal returns the body if there wasn't an error and the status
 | ||||
| // code was 200.
 | ||||
| func (r *result) getBodyForUnmarshal() ([]byte, error) { | ||||
| 	if r.Error() != nil { | ||||
| 		return nil, r.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	// Only unmarshal body if the response was successful
 | ||||
| 	if r.StatusCode() != http.StatusOK { | ||||
| 		return nil, fmt.Errorf("unexpected status \"%d\": %s", r.StatusCode(), r.Body()) | ||||
| 	} | ||||
| 
 | ||||
| 	return r.Body(), nil | ||||
| } | ||||
|  | @ -0,0 +1,326 @@ | |||
| package requests | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Result suite", func() { | ||||
| 	Context("with a result", func() { | ||||
| 		type resultTableInput struct { | ||||
| 			result             Result | ||||
| 			expectedError      error | ||||
| 			expectedStatusCode int | ||||
| 			expectedHeaders    http.Header | ||||
| 			expectedBody       []byte | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("accessors should return expected results", | ||||
| 			func(in resultTableInput) { | ||||
| 				if in.expectedError != nil { | ||||
| 					Expect(in.result.Error()).To(MatchError(in.expectedError)) | ||||
| 				} else { | ||||
| 					Expect(in.result.Error()).To(BeNil()) | ||||
| 				} | ||||
| 
 | ||||
| 				Expect(in.result.StatusCode()).To(Equal(in.expectedStatusCode)) | ||||
| 				Expect(in.result.Headers()).To(Equal(in.expectedHeaders)) | ||||
| 				Expect(in.result.Body()).To(Equal(in.expectedBody)) | ||||
| 			}, | ||||
| 			Entry("with an empty result", resultTableInput{ | ||||
| 				result:             &result{}, | ||||
| 				expectedError:      nil, | ||||
| 				expectedStatusCode: 0, | ||||
| 				expectedHeaders:    nil, | ||||
| 				expectedBody:       nil, | ||||
| 			}), | ||||
| 			Entry("with an error", resultTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: errors.New("error"), | ||||
| 				}, | ||||
| 				expectedError:      errors.New("error"), | ||||
| 				expectedStatusCode: 0, | ||||
| 				expectedHeaders:    nil, | ||||
| 				expectedBody:       nil, | ||||
| 			}), | ||||
| 			Entry("with a response with no headers", resultTableInput{ | ||||
| 				result: &result{ | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusTeapot, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedError:      nil, | ||||
| 				expectedStatusCode: http.StatusTeapot, | ||||
| 				expectedHeaders:    nil, | ||||
| 				expectedBody:       nil, | ||||
| 			}), | ||||
| 			Entry("with a response with no status code", resultTableInput{ | ||||
| 				result: &result{ | ||||
| 					response: &http.Response{ | ||||
| 						Header: http.Header{ | ||||
| 							"foo": []string{"bar"}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedError:      nil, | ||||
| 				expectedStatusCode: 0, | ||||
| 				expectedHeaders: http.Header{ | ||||
| 					"foo": []string{"bar"}, | ||||
| 				}, | ||||
| 				expectedBody: nil, | ||||
| 			}), | ||||
| 			Entry("with a response with a body", resultTableInput{ | ||||
| 				result: &result{ | ||||
| 					body: []byte("some body"), | ||||
| 				}, | ||||
| 				expectedError:      nil, | ||||
| 				expectedStatusCode: 0, | ||||
| 				expectedHeaders:    nil, | ||||
| 				expectedBody:       []byte("some body"), | ||||
| 			}), | ||||
| 			Entry("with all fields", resultTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: errors.New("some error"), | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusFound, | ||||
| 						Header: http.Header{ | ||||
| 							"header": []string{"value"}, | ||||
| 						}, | ||||
| 					}, | ||||
| 					body: []byte("a body"), | ||||
| 				}, | ||||
| 				expectedError:      errors.New("some error"), | ||||
| 				expectedStatusCode: http.StatusFound, | ||||
| 				expectedHeaders: http.Header{ | ||||
| 					"header": []string{"value"}, | ||||
| 				}, | ||||
| 				expectedBody: []byte("a body"), | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalInto", func() { | ||||
| 		type testStruct struct { | ||||
| 			A string `json:"a"` | ||||
| 			B int    `json:"b"` | ||||
| 		} | ||||
| 
 | ||||
| 		type unmarshalIntoTableInput struct { | ||||
| 			result         Result | ||||
| 			expectedErr    error | ||||
| 			expectedOutput *testStruct | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with a result", | ||||
| 			func(in unmarshalIntoTableInput) { | ||||
| 				input := &testStruct{} | ||||
| 				err := in.result.UnmarshalInto(input) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(input).To(Equal(in.expectedOutput)) | ||||
| 			}, | ||||
| 			Entry("with an error", unmarshalIntoTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: errors.New("got an error"), | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("{\"a\": \"foo\"}"), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("got an error"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 			Entry("with a 409 status code", unmarshalIntoTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusConflict, | ||||
| 					}, | ||||
| 					body: []byte("{\"a\": \"foo\"}"), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 			Entry("when the response has a valid json response", unmarshalIntoTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("{\"a\": \"foo\", \"b\": 1}"), | ||||
| 				}, | ||||
| 				expectedErr:    nil, | ||||
| 				expectedOutput: &testStruct{A: "foo", B: 1}, | ||||
| 			}), | ||||
| 			Entry("when the response body is empty", unmarshalIntoTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte(""), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("error unmarshalling body: unexpected end of JSON input"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 			Entry("when the response body is not json", unmarshalIntoTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("not json"), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("error unmarshalling body: invalid character 'o' in literal null (expecting 'u')"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("UnmarshalJSON", func() { | ||||
| 		type testStruct struct { | ||||
| 			A string `json:"a"` | ||||
| 			B int    `json:"b"` | ||||
| 		} | ||||
| 
 | ||||
| 		type unmarshalJSONTableInput struct { | ||||
| 			result         Result | ||||
| 			expectedErr    error | ||||
| 			expectedOutput *testStruct | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with a result", | ||||
| 			func(in unmarshalJSONTableInput) { | ||||
| 				j, err := in.result.UnmarshalJSON() | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 					Expect(j).To(BeNil()) | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				// No error so j should not be nil
 | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				input := &testStruct{ | ||||
| 					A: j.Get("a").MustString(), | ||||
| 					B: j.Get("b").MustInt(), | ||||
| 				} | ||||
| 				Expect(input).To(Equal(in.expectedOutput)) | ||||
| 			}, | ||||
| 			Entry("with an error", unmarshalJSONTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: errors.New("got an error"), | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("{\"a\": \"foo\"}"), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("got an error"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 			Entry("with a 409 status code", unmarshalJSONTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusConflict, | ||||
| 					}, | ||||
| 					body: []byte("{\"a\": \"foo\"}"), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("unexpected status \"409\": {\"a\": \"foo\"}"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 			Entry("when the response has a valid json response", unmarshalJSONTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("{\"a\": \"foo\", \"b\": 1}"), | ||||
| 				}, | ||||
| 				expectedErr:    nil, | ||||
| 				expectedOutput: &testStruct{A: "foo", B: 1}, | ||||
| 			}), | ||||
| 			Entry("when the response body is empty", unmarshalJSONTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte(""), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("error reading json: EOF"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 			Entry("when the response body is not json", unmarshalJSONTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("not json"), | ||||
| 				}, | ||||
| 				expectedErr:    errors.New("error reading json: invalid character 'o' in literal null (expecting 'u')"), | ||||
| 				expectedOutput: &testStruct{}, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("getBodyForUnmarshal", func() { | ||||
| 		type getBodyForUnmarshalTableInput struct { | ||||
| 			result       *result | ||||
| 			expectedErr  error | ||||
| 			expectedBody []byte | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("when getting the body", func(in getBodyForUnmarshalTableInput) { | ||||
| 			body, err := in.result.getBodyForUnmarshal() | ||||
| 			if in.expectedErr != nil { | ||||
| 				Expect(err).To(MatchError(in.expectedErr)) | ||||
| 			} else { | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			} | ||||
| 			Expect(body).To(Equal(in.expectedBody)) | ||||
| 		}, | ||||
| 			Entry("when the result has an error", getBodyForUnmarshalTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: errors.New("got an error"), | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("body"), | ||||
| 				}, | ||||
| 				expectedErr:  errors.New("got an error"), | ||||
| 				expectedBody: nil, | ||||
| 			}), | ||||
| 			Entry("when the response has a 409 status code", getBodyForUnmarshalTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusConflict, | ||||
| 					}, | ||||
| 					body: []byte("body"), | ||||
| 				}, | ||||
| 				expectedErr:  errors.New("unexpected status \"409\": body"), | ||||
| 				expectedBody: nil, | ||||
| 			}), | ||||
| 			Entry("when the response has a 200 status code", getBodyForUnmarshalTableInput{ | ||||
| 				result: &result{ | ||||
| 					err: nil, | ||||
| 					response: &http.Response{ | ||||
| 						StatusCode: http.StatusOK, | ||||
| 					}, | ||||
| 					body: []byte("body"), | ||||
| 				}, | ||||
| 				expectedErr:  nil, | ||||
| 				expectedBody: []byte("body"), | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -83,34 +83,34 @@ func Validate(o *options.Options) error { | |||
| 
 | ||||
| 			logger.Printf("Performing OIDC Discovery...") | ||||
| 
 | ||||
| 			if req, err := http.NewRequest("GET", strings.TrimSuffix(o.OIDCIssuerURL, "/")+"/.well-known/openid-configuration", nil); err == nil { | ||||
| 				if body, err := requests.Request(req); err == nil { | ||||
| 
 | ||||
| 					// Prefer manually configured URLs. It's a bit unclear
 | ||||
| 					// why you'd be doing discovery and also providing the URLs
 | ||||
| 					// explicitly though...
 | ||||
| 					if o.LoginURL == "" { | ||||
| 						o.LoginURL = body.Get("authorization_endpoint").MustString() | ||||
| 					} | ||||
| 
 | ||||
| 					if o.RedeemURL == "" { | ||||
| 						o.RedeemURL = body.Get("token_endpoint").MustString() | ||||
| 					} | ||||
| 
 | ||||
| 					if o.OIDCJwksURL == "" { | ||||
| 						o.OIDCJwksURL = body.Get("jwks_uri").MustString() | ||||
| 					} | ||||
| 
 | ||||
| 					if o.ProfileURL == "" { | ||||
| 						o.ProfileURL = body.Get("userinfo_endpoint").MustString() | ||||
| 					} | ||||
| 
 | ||||
| 					o.SkipOIDCDiscovery = true | ||||
| 				} else { | ||||
| 					logger.Printf("error: failed to discover OIDC configuration: %v", err) | ||||
| 				} | ||||
| 			requestURL := strings.TrimSuffix(o.OIDCIssuerURL, "/") + "/.well-known/openid-configuration" | ||||
| 			body, err := requests.New(requestURL). | ||||
| 				WithContext(ctx). | ||||
| 				Do(). | ||||
| 				UnmarshalJSON() | ||||
| 			if err != nil { | ||||
| 				logger.Printf("error: failed to discover OIDC configuration: %v", err) | ||||
| 			} else { | ||||
| 				logger.Printf("error: failed parsing OIDC discovery URL: %v", err) | ||||
| 				// Prefer manually configured URLs. It's a bit unclear
 | ||||
| 				// why you'd be doing discovery and also providing the URLs
 | ||||
| 				// explicitly though...
 | ||||
| 				if o.LoginURL == "" { | ||||
| 					o.LoginURL = body.Get("authorization_endpoint").MustString() | ||||
| 				} | ||||
| 
 | ||||
| 				if o.RedeemURL == "" { | ||||
| 					o.RedeemURL = body.Get("token_endpoint").MustString() | ||||
| 				} | ||||
| 
 | ||||
| 				if o.OIDCJwksURL == "" { | ||||
| 					o.OIDCJwksURL = body.Get("jwks_uri").MustString() | ||||
| 				} | ||||
| 
 | ||||
| 				if o.ProfileURL == "" { | ||||
| 					o.ProfileURL = body.Get("userinfo_endpoint").MustString() | ||||
| 				} | ||||
| 
 | ||||
| 				o.SkipOIDCDiscovery = true | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
|  | @ -385,10 +385,10 @@ func newVerifierFromJwtIssuer(jwtIssuer jwtIssuer) (*oidc.IDTokenVerifier, error | |||
| 	if err != nil { | ||||
| 		// Try as JWKS URI
 | ||||
| 		jwksURI := strings.TrimSuffix(jwtIssuer.issuerURI, "/") + "/.well-known/jwks.json" | ||||
| 		_, err := http.NewRequest("GET", jwksURI, nil) | ||||
| 		if err != nil { | ||||
| 		if err := requests.New(jwksURI).Do().Error(); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		verifier = oidc.NewVerifier(jwtIssuer.issuerURI, oidc.NewRemoteKeySet(context.Background(), jwksURI), config) | ||||
| 	} else { | ||||
| 		verifier = provider.Verifier(config) | ||||
|  |  | |||
|  | @ -3,10 +3,8 @@ package providers | |||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
|  | @ -91,39 +89,22 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 		params.Add("resource", p.ProtectedResource.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	var resp *http.Response | ||||
| 	resp, err = http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var jsonResponse struct { | ||||
| 		AccessToken  string `json:"access_token"` | ||||
| 		RefreshToken string `json:"refresh_token"` | ||||
| 		ExpiresOn    int64  `json:"expires_on,string"` | ||||
| 		IDToken      string `json:"id_token"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 
 | ||||
| 	err = requests.New(p.RedeemURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithMethod("POST"). | ||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | ||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&jsonResponse) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
|  | @ -169,26 +150,22 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session | |||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header = getAzureHeader(s.AccessToken) | ||||
| 
 | ||||
| 	json, err := requests.Request(req) | ||||
| 
 | ||||
| 	json, err := requests.New(p.ProfileURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getAzureHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	email, err = getEmailFromJSON(json) | ||||
| 
 | ||||
| 	if err == nil && email != "" { | ||||
| 		return email, err | ||||
| 	} | ||||
| 
 | ||||
| 	email, err = json.Get("userPrincipalName").String() | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed making request %s", err) | ||||
| 		return "", err | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 
 | ||||
|  | @ -85,15 +84,14 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | |||
| 			FullName string `json:"full_name"` | ||||
| 		} | ||||
| 	} | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", | ||||
| 		p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) | ||||
| 
 | ||||
| 	requestURL := p.ValidateURL.String() + "?access_token=" + s.AccessToken | ||||
| 	err := requests.New(requestURL). | ||||
| 		WithContext(ctx). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&emails) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed building request %s", err) | ||||
| 		return "", err | ||||
| 	} | ||||
| 	err = requests.RequestJSON(req, &emails) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed making request %s", err) | ||||
| 		logger.Printf("failed making request: %v", err) | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
|  | @ -101,15 +99,15 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | |||
| 		teamURL := &url.URL{} | ||||
| 		*teamURL = *p.ValidateURL | ||||
| 		teamURL.Path = "/2.0/teams" | ||||
| 		req, err = http.NewRequestWithContext(ctx, "GET", | ||||
| 			teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) | ||||
| 
 | ||||
| 		requestURL := teamURL.String() + "?role=member&access_token=" + s.AccessToken | ||||
| 
 | ||||
| 		err := requests.New(requestURL). | ||||
| 			WithContext(ctx). | ||||
| 			Do(). | ||||
| 			UnmarshalInto(&teams) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("failed building request %s", err) | ||||
| 			return "", err | ||||
| 		} | ||||
| 		err = requests.RequestJSON(req, &teams) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("failed requesting teams membership %s", err) | ||||
| 			logger.Printf("failed requesting teams membership: %v", err) | ||||
| 			return "", err | ||||
| 		} | ||||
| 		var found = false | ||||
|  | @ -129,20 +127,20 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses | |||
| 		repositoriesURL := &url.URL{} | ||||
| 		*repositoriesURL = *p.ValidateURL | ||||
| 		repositoriesURL.Path = "/2.0/repositories/" + strings.Split(p.Repository, "/")[0] | ||||
| 		req, err = http.NewRequestWithContext(ctx, "GET", | ||||
| 			repositoriesURL.String()+"?role=contributor"+ | ||||
| 				"&q=full_name="+url.QueryEscape("\""+p.Repository+"\"")+ | ||||
| 				"&access_token="+s.AccessToken, | ||||
| 			nil) | ||||
| 
 | ||||
| 		requestURL := repositoriesURL.String() + "?role=contributor" + | ||||
| 			"&q=full_name=" + url.QueryEscape("\""+p.Repository+"\"") + | ||||
| 			"&access_token=" + s.AccessToken | ||||
| 
 | ||||
| 		err := requests.New(requestURL). | ||||
| 			WithContext(ctx). | ||||
| 			Do(). | ||||
| 			UnmarshalInto(&repositories) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("failed building request %s", err) | ||||
| 			return "", err | ||||
| 		} | ||||
| 		err = requests.RequestJSON(req, &repositories) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("failed checking repository access %s", err) | ||||
| 			logger.Printf("failed checking repository access: %v", err) | ||||
| 			return "", err | ||||
| 		} | ||||
| 
 | ||||
| 		var found = false | ||||
| 		for _, repository := range repositories.Values { | ||||
| 			if p.Repository == repository.FullName { | ||||
|  |  | |||
|  | @ -60,13 +60,12 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. | |||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String(), nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header = getDigitalOceanHeader(s.AccessToken) | ||||
| 
 | ||||
| 	json, err := requests.Request(req) | ||||
| 	json, err := requests.New(p.ProfileURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getDigitalOceanHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  |  | |||
|  | @ -62,20 +62,22 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | |||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?fields=name,email", nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header = getFacebookHeader(s.AccessToken) | ||||
| 
 | ||||
| 	type result struct { | ||||
| 		Email string | ||||
| 	} | ||||
| 	var r result | ||||
| 	err = requests.RequestJSON(req, &r) | ||||
| 
 | ||||
| 	requestURL := p.ProfileURL.String() + "?fields=name,email" | ||||
| 	err := requests.New(requestURL). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getFacebookHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&r) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if r.Email == "" { | ||||
| 		return "", errors.New("no email") | ||||
| 	} | ||||
|  |  | |||
|  | @ -2,10 +2,8 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"path" | ||||
|  | @ -15,6 +13,7 @@ import ( | |||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
| ) | ||||
| 
 | ||||
| // GitHubProvider represents an GitHub based Identity Provider
 | ||||
|  | @ -111,27 +110,17 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, | |||
| 			Path:     path.Join(p.ValidateURL.Path, "/user/orgs"), | ||||
| 			RawQuery: params.Encode(), | ||||
| 		} | ||||
| 		req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 		req.Header = getGitHubHeader(accessToken) | ||||
| 		resp, err := http.DefaultClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return false, err | ||||
| 		} | ||||
| 
 | ||||
| 		body, err := ioutil.ReadAll(resp.Body) | ||||
| 		resp.Body.Close() | ||||
| 		if err != nil { | ||||
| 			return false, err | ||||
| 		} | ||||
| 		if resp.StatusCode != 200 { | ||||
| 			return false, fmt.Errorf( | ||||
| 				"got %d from %q %s", resp.StatusCode, endpoint.String(), body) | ||||
| 		} | ||||
| 
 | ||||
| 		var op orgsPage | ||||
| 		if err := json.Unmarshal(body, &op); err != nil { | ||||
| 		err := requests.New(endpoint.String()). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(getGitHubHeader(accessToken)). | ||||
| 			Do(). | ||||
| 			UnmarshalInto(&op) | ||||
| 		if err != nil { | ||||
| 			return false, err | ||||
| 		} | ||||
| 
 | ||||
| 		if len(op) == 0 { | ||||
| 			break | ||||
| 		} | ||||
|  | @ -187,11 +176,15 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | |||
| 			RawQuery: params.Encode(), | ||||
| 		} | ||||
| 
 | ||||
| 		req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 		req.Header = getGitHubHeader(accessToken) | ||||
| 		resp, err := http.DefaultClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return false, err | ||||
| 		// bodyclose cannot detect that the body is being closed later in requests.Into,
 | ||||
| 		// so have to skip the linting for the next line.
 | ||||
| 		// nolint:bodyclose
 | ||||
| 		result := requests.New(endpoint.String()). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(getGitHubHeader(accessToken)). | ||||
| 			Do() | ||||
| 		if result.Error() != nil { | ||||
| 			return false, result.Error() | ||||
| 		} | ||||
| 
 | ||||
| 		if last == 0 { | ||||
|  | @ -207,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | |||
| 			// link header at last page (doesn't exist last info)
 | ||||
| 			// <https://api.github.com/user/teams?page=3&per_page=10>; rel="prev", <https://api.github.com/user/teams?page=1&per_page=10>; rel="first"
 | ||||
| 
 | ||||
| 			link := resp.Header.Get("Link") | ||||
| 			link := result.Headers().Get("Link") | ||||
| 			rep1 := regexp.MustCompile(`(?s).*\<https://api.github.com/user/teams\?page=(.)&per_page=[0-9]+\>; rel="last".*`) | ||||
| 			i, converr := strconv.Atoi(rep1.ReplaceAllString(link, "$1")) | ||||
| 
 | ||||
|  | @ -217,21 +210,9 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) | |||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		body, err := ioutil.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			resp.Body.Close() | ||||
| 			return false, err | ||||
| 		} | ||||
| 		resp.Body.Close() | ||||
| 
 | ||||
| 		if resp.StatusCode != 200 { | ||||
| 			return false, fmt.Errorf( | ||||
| 				"got %d from %q %s", resp.StatusCode, endpoint.String(), body) | ||||
| 		} | ||||
| 
 | ||||
| 		var tp teamsPage | ||||
| 		if err := json.Unmarshal(body, &tp); err != nil { | ||||
| 			return false, fmt.Errorf("%s unmarshaling %s", err, body) | ||||
| 		if err := result.UnmarshalInto(&tp); err != nil { | ||||
| 			return false, err | ||||
| 		} | ||||
| 		if len(tp) == 0 { | ||||
| 			break | ||||
|  | @ -297,25 +278,13 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, | |||
| 		Path:   path.Join(p.ValidateURL.Path, "/repo/", p.Repo), | ||||
| 	} | ||||
| 
 | ||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 	req.Header = getGitHubHeader(accessToken) | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 
 | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return false, fmt.Errorf( | ||||
| 			"got %d from %q %s", resp.StatusCode, endpoint.String(), body) | ||||
| 	} | ||||
| 
 | ||||
| 	var repo repository | ||||
| 	if err := json.Unmarshal(body, &repo); err != nil { | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(accessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&repo) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 
 | ||||
|  | @ -337,26 +306,15 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, | |||
| 		Host:   p.ValidateURL.Host, | ||||
| 		Path:   path.Join(p.ValidateURL.Path, "/user"), | ||||
| 	} | ||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 	req.Header = getGitHubHeader(accessToken) | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 
 | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(accessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&user) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return false, fmt.Errorf("got %d from %q %s", | ||||
| 			resp.StatusCode, stripToken(endpoint.String()), body) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := json.Unmarshal(body, &user); err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 
 | ||||
| 	if p.isVerifiedUser(user.Login) { | ||||
| 		return true, nil | ||||
|  | @ -372,24 +330,20 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok | |||
| 		Host:   p.ValidateURL.Host, | ||||
| 		Path:   path.Join(p.ValidateURL.Path, "/repos/", p.Repo, "/collaborators/", username), | ||||
| 	} | ||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 	req.Header = getGitHubHeader(accessToken) | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	result := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(accessToken)). | ||||
| 		Do() | ||||
| 	if result.Error() != nil { | ||||
| 		return false, result.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 204 { | ||||
| 	if result.StatusCode() != 204 { | ||||
| 		return false, fmt.Errorf("got %d from %q %s", | ||||
| 			resp.StatusCode, endpoint.String(), body) | ||||
| 			result.StatusCode(), endpoint.String(), result.Body()) | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) | ||||
| 	logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) | ||||
| 
 | ||||
| 	return true, nil | ||||
| } | ||||
|  | @ -440,28 +394,14 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio | |||
| 		Host:   p.ValidateURL.Host, | ||||
| 		Path:   path.Join(p.ValidateURL.Path, "/user/emails"), | ||||
| 	} | ||||
| 	req, _ := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 	req.Header = getGitHubHeader(s.AccessToken) | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&emails) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return "", fmt.Errorf("got %d from %q %s", | ||||
| 			resp.StatusCode, endpoint.String(), body) | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) | ||||
| 
 | ||||
| 	if err := json.Unmarshal(body, &emails); err != nil { | ||||
| 		return "", fmt.Errorf("%s unmarshaling %s", err, body) | ||||
| 	} | ||||
| 
 | ||||
| 	returnEmail := "" | ||||
| 	for _, email := range emails { | ||||
|  | @ -489,34 +429,15 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta | |||
| 		Path:   path.Join(p.ValidateURL.Path, "/user"), | ||||
| 	} | ||||
| 
 | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", endpoint.String(), nil) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("could not create new GET request: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	req.Header = getGitHubHeader(s.AccessToken) | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	err := requests.New(endpoint.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getGitHubHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&user) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	body, err := ioutil.ReadAll(resp.Body) | ||||
| 	defer resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return "", fmt.Errorf("got %d from %q %s", | ||||
| 			resp.StatusCode, endpoint.String(), body) | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) | ||||
| 
 | ||||
| 	if err := json.Unmarshal(body, &user); err != nil { | ||||
| 		return "", fmt.Errorf("%s unmarshaling %s", err, body) | ||||
| 	} | ||||
| 
 | ||||
| 	// Now that we have the username we can check collaborator status
 | ||||
| 	if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { | ||||
| 		if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { | ||||
|  |  | |||
|  | @ -2,15 +2,13 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	oidc "github.com/coreos/go-oidc" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
|  | @ -131,31 +129,14 @@ func (p *GitLabProvider) getUserInfo(ctx context.Context, s *sessions.SessionSta | |||
| 	userInfoURL := *p.LoginURL | ||||
| 	userInfoURL.Path = "/oauth/userinfo" | ||||
| 
 | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL.String(), nil) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create user info request: %v", err) | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+s.AccessToken) | ||||
| 
 | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to perform user info request: %v", err) | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to read user info response: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		return nil, fmt.Errorf("got %d during user info request: %s", resp.StatusCode, body) | ||||
| 	} | ||||
| 
 | ||||
| 	var userInfo gitlabUserInfo | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
| 	err := requests.New(userInfoURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		SetHeader("Authorization", "Bearer "+s.AccessToken). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&userInfo) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to parse user info: %v", err) | ||||
| 		return nil, fmt.Errorf("error getting user info: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return &userInfo, nil | ||||
|  |  | |||
|  | @ -9,13 +9,13 @@ import ( | |||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
| 	"golang.org/x/oauth2/google" | ||||
| 	admin "google.golang.org/api/admin/directory/v1" | ||||
| 	"google.golang.org/api/googleapi" | ||||
|  | @ -116,28 +116,6 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | |||
| 	params.Add("client_secret", clientSecret) | ||||
| 	params.Add("code", code) | ||||
| 	params.Add("grant_type", "authorization_code") | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var jsonResponse struct { | ||||
| 		AccessToken  string `json:"access_token"` | ||||
|  | @ -145,10 +123,18 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | |||
| 		ExpiresIn    int64  `json:"expires_in"` | ||||
| 		IDToken      string `json:"id_token"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 
 | ||||
| 	err = requests.New(p.RedeemURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithMethod("POST"). | ||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | ||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&jsonResponse) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	c, err := claimsFromIDToken(jsonResponse.IDToken) | ||||
| 	if err != nil { | ||||
| 		return | ||||
|  | @ -283,38 +269,24 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st | |||
| 	params.Add("client_secret", clientSecret) | ||||
| 	params.Add("refresh_token", refreshToken) | ||||
| 	params.Add("grant_type", "refresh_token") | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	var data struct { | ||||
| 		AccessToken string `json:"access_token"` | ||||
| 		ExpiresIn   int64  `json:"expires_in"` | ||||
| 		IDToken     string `json:"id_token"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &data) | ||||
| 
 | ||||
| 	err = requests.New(p.RedeemURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithMethod("POST"). | ||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | ||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&data) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return "", "", 0, err | ||||
| 	} | ||||
| 
 | ||||
| 	token = data.AccessToken | ||||
| 	idToken = data.IDToken | ||||
| 	expires = time.Duration(data.ExpiresIn) * time.Second | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
|  | @ -56,20 +55,22 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h | |||
| 		params := url.Values{"access_token": {accessToken}} | ||||
| 		endpoint = endpoint + "?" + params.Encode() | ||||
| 	} | ||||
| 	resp, err := requests.RequestUnparsedResponse(ctx, endpoint, header) | ||||
| 	if err != nil { | ||||
| 
 | ||||
| 	result := requests.New(endpoint). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(header). | ||||
| 		Do() | ||||
| 	if result.Error() != nil { | ||||
| 		logger.Printf("GET %s", stripToken(endpoint)) | ||||
| 		logger.Printf("token validation request failed: %s", err) | ||||
| 		logger.Printf("token validation request failed: %s", result.Error()) | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	body, _ := ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	logger.Printf("%d GET %s %s", resp.StatusCode, stripToken(endpoint), body) | ||||
| 	logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) | ||||
| 
 | ||||
| 	if resp.StatusCode == 200 { | ||||
| 	if result.StatusCode() == 200 { | ||||
| 		return true | ||||
| 	} | ||||
| 	logger.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) | ||||
| 	logger.Printf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) | ||||
| 	return false | ||||
| } | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
|  | @ -51,14 +50,11 @@ func (p *KeycloakProvider) SetGroup(group string) { | |||
| } | ||||
| 
 | ||||
| func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||
| 
 | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ValidateURL.String(), nil) | ||||
| 	req.Header.Set("Authorization", "Bearer "+s.AccessToken) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed building request %s", err) | ||||
| 		return "", err | ||||
| 	} | ||||
| 	json, err := requests.Request(req) | ||||
| 	json, err := requests.New(p.ValidateURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		SetHeader("Authorization", "Bearer "+s.AccessToken). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed making request %s", err) | ||||
| 		return "", err | ||||
|  |  | |||
|  | @ -58,13 +58,13 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | |||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", p.ProfileURL.String()+"?format=json", nil) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header = getLinkedInHeader(s.AccessToken) | ||||
| 
 | ||||
| 	json, err := requests.Request(req) | ||||
| 	requestURL := p.ProfileURL.String() + "?format=json" | ||||
| 	json, err := requests.New(requestURL). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getLinkedInHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  |  | |||
|  | @ -15,6 +15,7 @@ import ( | |||
| 
 | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
| 	"gopkg.in/square/go-jose.v2" | ||||
| ) | ||||
| 
 | ||||
|  | @ -128,51 +129,34 @@ func checkNonce(idToken string, p *LoginGovProvider) (err error) { | |||
| 	return | ||||
| } | ||||
| 
 | ||||
| func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (email string, err error) { | ||||
| 	// query the user info endpoint for user attributes
 | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequestWithContext(ctx, "GET", userInfoEndpoint, nil) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Authorization", "Bearer "+accessToken) | ||||
| 
 | ||||
| 	resp, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint string) (string, error) { | ||||
| 	// parse the user attributes from the data we got and make sure that
 | ||||
| 	// the email address has been validated.
 | ||||
| 	var emailData struct { | ||||
| 		Email         string `json:"email"` | ||||
| 		EmailVerified bool   `json:"email_verified"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &emailData) | ||||
| 
 | ||||
| 	// query the user info endpoint for user attributes
 | ||||
| 	err := requests.New(userInfoEndpoint). | ||||
| 		WithContext(ctx). | ||||
| 		SetHeader("Authorization", "Bearer "+accessToken). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&emailData) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return "", err | ||||
| 	} | ||||
| 	if emailData.Email == "" { | ||||
| 		err = fmt.Errorf("missing email") | ||||
| 		return | ||||
| 
 | ||||
| 	email := emailData.Email | ||||
| 	if email == "" { | ||||
| 		return "", fmt.Errorf("missing email") | ||||
| 	} | ||||
| 	email = emailData.Email | ||||
| 
 | ||||
| 	if !emailData.EmailVerified { | ||||
| 		err = fmt.Errorf("email %s not listed as verified", email) | ||||
| 		return | ||||
| 		return "", fmt.Errorf("email %s not listed as verified", email) | ||||
| 	} | ||||
| 	return | ||||
| 
 | ||||
| 	return email, nil | ||||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
|  | @ -201,30 +185,6 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | |||
| 	params.Add("code", code) | ||||
| 	params.Add("grant_type", "authorization_code") | ||||
| 
 | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	var resp *http.Response | ||||
| 	resp, err = http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Get the token from the body that we got from the token endpoint.
 | ||||
| 	var jsonResponse struct { | ||||
| 		AccessToken string `json:"access_token"` | ||||
|  | @ -232,9 +192,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | |||
| 		TokenType   string `json:"token_type"` | ||||
| 		ExpiresIn   int64  `json:"expires_in"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 	err = requests.New(p.RedeemURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithMethod("POST"). | ||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | ||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&jsonResponse) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// check nonce here
 | ||||
|  |  | |||
|  | @ -6,7 +6,6 @@ import ( | |||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
| ) | ||||
| 
 | ||||
|  | @ -31,18 +30,15 @@ func getNextcloudHeader(accessToken string) http.Header { | |||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { | ||||
| 	req, err := http.NewRequestWithContext(ctx, "GET", | ||||
| 		p.ValidateURL.String(), nil) | ||||
| 	json, err := requests.New(p.ValidateURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithHeaders(getNextcloudHeader(s.AccessToken)). | ||||
| 		Do(). | ||||
| 		UnmarshalJSON() | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed building request %s", err) | ||||
| 		return "", err | ||||
| 	} | ||||
| 	req.Header = getNextcloudHeader(s.AccessToken) | ||||
| 	json, err := requests.Request(req) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("failed making request %s", err) | ||||
| 		return "", err | ||||
| 		return "", fmt.Errorf("error making request: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	email, err := json.Get("ocs").Get("data").Get("email").String() | ||||
| 	return email, err | ||||
| } | ||||
|  |  | |||
|  | @ -256,13 +256,11 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. | |||
| 		// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
 | ||||
| 		// contents at the profileURL contains the email.
 | ||||
| 		// Make a query to the userinfo endpoint, and attempt to locate the email from there.
 | ||||
| 		req, err := http.NewRequestWithContext(ctx, "GET", profileURL, nil) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		req.Header = getOIDCHeader(accessToken) | ||||
| 
 | ||||
| 		respJSON, err := requests.Request(req) | ||||
| 		respJSON, err := requests.New(profileURL). | ||||
| 			WithContext(ctx). | ||||
| 			WithHeaders(getOIDCHeader(accessToken)). | ||||
| 			Do(). | ||||
| 			UnmarshalJSON() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  |  | |||
|  | @ -3,17 +3,15 @@ package providers | |||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/requests" | ||||
| ) | ||||
| 
 | ||||
| var _ Provider = (*ProviderData)(nil) | ||||
|  | @ -39,35 +37,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 		params.Add("resource", p.ProtectedResource.String()) | ||||
| 	} | ||||
| 
 | ||||
| 	var req *http.Request | ||||
| 	req, err = http.NewRequestWithContext(ctx, "POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
| 
 | ||||
| 	var resp *http.Response | ||||
| 	resp, err = http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	var body []byte | ||||
| 	body, err = ioutil.ReadAll(resp.Body) | ||||
| 	resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if resp.StatusCode != 200 { | ||||
| 		err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) | ||||
| 		return | ||||
| 	result := requests.New(p.RedeemURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithMethod("POST"). | ||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | ||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||
| 		Do() | ||||
| 	if result.Error() != nil { | ||||
| 		return nil, result.Error() | ||||
| 	} | ||||
| 
 | ||||
| 	// blindly try json and x-www-form-urlencoded
 | ||||
| 	var jsonResponse struct { | ||||
| 		AccessToken string `json:"access_token"` | ||||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 	err = result.UnmarshalInto(&jsonResponse) | ||||
| 	if err == nil { | ||||
| 		s = &sessions.SessionState{ | ||||
| 			AccessToken: jsonResponse.AccessToken, | ||||
|  | @ -76,7 +60,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 	} | ||||
| 
 | ||||
| 	var v url.Values | ||||
| 	v, err = url.ParseQuery(string(body)) | ||||
| 	v, err = url.ParseQuery(string(result.Body())) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | @ -84,7 +68,7 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s | |||
| 		created := time.Now() | ||||
| 		s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} | ||||
| 	} else { | ||||
| 		err = fmt.Errorf("no access token found %s", body) | ||||
| 		err = fmt.Errorf("no access token found %s", result.Body()) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue