From e932381ba789bbbad86ca5001f7b34c413281009 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Tue, 26 May 2020 19:56:10 +0100 Subject: [PATCH 1/5] Add LegacyOptions and conversion to new Options This will be temporary until we switch to structured config, then we can remove the LegacyOptions and conversions --- pkg/apis/options/legacy_options.go | 98 ++++++++++++ pkg/apis/options/legacy_options_test.go | 194 ++++++++++++++++++++++++ pkg/apis/options/load_test.go | 11 +- pkg/apis/options/options.go | 48 +++--- pkg/apis/options/options_suite_test.go | 16 ++ pkg/validation/options.go | 12 +- pkg/validation/options_test.go | 26 +--- 7 files changed, 339 insertions(+), 66 deletions(-) create mode 100644 pkg/apis/options/legacy_options.go create mode 100644 pkg/apis/options/legacy_options_test.go create mode 100644 pkg/apis/options/options_suite_test.go diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go new file mode 100644 index 00000000..b0198c9c --- /dev/null +++ b/pkg/apis/options/legacy_options.go @@ -0,0 +1,98 @@ +package options + +import ( + "fmt" + "net/url" + "strconv" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +type LegacyOptions struct { + // Legacy options related to upstream servers + LegacyFlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval"` + LegacyPassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` + LegacyProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` + LegacySSLUpstreamInsecureSkipVerify bool `flag:"ssl-upstream-insecure-skip-verify" cfg:"ssl_upstream_insecure_skip_verify"` + LegacyUpstreams []string `flag:"upstream" cfg:"upstreams"` + + Options Options `cfg:",squash"` +} + +func NewLegacyOptions() *LegacyOptions { + return &LegacyOptions{ + LegacyPassHostHeader: true, + LegacyProxyWebSockets: true, + LegacyFlushInterval: time.Duration(1) * time.Second, + + Options: *NewOptions(), + } +} + +func (l *LegacyOptions) ToOptions() (*Options, error) { + upstreams, err := convertLegacyUpstreams(l.LegacyUpstreams, l.LegacySSLUpstreamInsecureSkipVerify, l.LegacyPassHostHeader, l.LegacyProxyWebSockets, l.LegacyFlushInterval) + if err != nil { + return nil, fmt.Errorf("error converting upstreams: %v", err) + } + l.Options.UpstreamServers = upstreams + + return &l.Options, nil +} + +func convertLegacyUpstreams(upstreamStrings []string, skipVerify, passHostHeader, proxyWebSockets bool, flushInterval time.Duration) (Upstreams, error) { + upstreams := Upstreams{} + + for _, upstreamString := range upstreamStrings { + u, err := url.Parse(upstreamString) + if err != nil { + return nil, fmt.Errorf("could not parse upstream %q: %v", upstreamString, err) + } + + if u.Path == "" { + u.Path = "/" + } + + upstream := Upstream{ + ID: u.Path, + Path: u.Path, + URI: upstreamString, + InsecureSkipTLSVerify: skipVerify, + PassHostHeader: passHostHeader, + ProxyWebSockets: proxyWebSockets, + FlushInterval: &flushInterval, + } + + switch u.Scheme { + case "file": + if u.Fragment != "" { + upstream.ID = u.Fragment + upstream.Path = u.Fragment + } + case "static": + responseCode, err := strconv.Atoi(u.Host) + if err != nil { + logger.Printf("unable to convert %q to int, use default \"200\"", u.Host) + responseCode = 200 + } + upstream.Static = true + upstream.StaticCode = &responseCode + + // These are not allowed to be empty and must be unique + upstream.ID = upstreamString + upstream.Path = upstreamString + + // Force defaults compatible with static responses + upstream.URI = "" + upstream.InsecureSkipTLSVerify = false + upstream.PassHostHeader = true + upstream.ProxyWebSockets = false + flush := 1 * time.Second + upstream.FlushInterval = &flush + } + + upstreams = append(upstreams, upstream) + } + + return upstreams, nil +} diff --git a/pkg/apis/options/legacy_options_test.go b/pkg/apis/options/legacy_options_test.go new file mode 100644 index 00000000..0e221749 --- /dev/null +++ b/pkg/apis/options/legacy_options_test.go @@ -0,0 +1,194 @@ +package options + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("Legacy Options", func() { + Context("ToOptions", func() { + It("converts the options as expected", func() { + opts := NewOptions() + + legacyOpts := NewLegacyOptions() + + // Set upstreams and related options to test their conversion + flushInterval := 5 * time.Second + legacyOpts.LegacyFlushInterval = flushInterval + legacyOpts.LegacyPassHostHeader = true + legacyOpts.LegacyProxyWebSockets = true + legacyOpts.LegacySSLUpstreamInsecureSkipVerify = true + legacyOpts.LegacyUpstreams = []string{"http://foo.bar/baz", "file://var/lib/website#/bar"} + + opts.UpstreamServers = Upstreams{ + { + ID: "/baz", + Path: "/baz", + URI: "http://foo.bar/baz", + FlushInterval: &flushInterval, + InsecureSkipTLSVerify: true, + PassHostHeader: true, + ProxyWebSockets: true, + }, + { + ID: "/bar", + Path: "/bar", + URI: "file://var/lib/website#/bar", + FlushInterval: &flushInterval, + InsecureSkipTLSVerify: true, + PassHostHeader: true, + ProxyWebSockets: true, + }, + } + + converted, err := legacyOpts.ToOptions() + Expect(err).ToNot(HaveOccurred()) + Expect(converted).To(Equal(opts)) + }) + }) + + Context("Legacy Upstreams", func() { + type convertUpstreamsTableInput struct { + upstreamStrings []string + expectedUpstreams Upstreams + errMsg string + } + + defaultFlushInterval := 1 * time.Second + + // Non defaults for these options + skipVerify := true + passHostHeader := false + proxyWebSockets := true + flushInterval := 5 * time.Second + + // Test cases and expected outcomes + validHTTP := "http://foo.bar/baz" + validHTTPUpstream := Upstream{ + ID: "/baz", + Path: "/baz", + URI: validHTTP, + InsecureSkipTLSVerify: skipVerify, + PassHostHeader: passHostHeader, + ProxyWebSockets: proxyWebSockets, + FlushInterval: &flushInterval, + } + + // Test cases and expected outcomes + emptyPathHTTP := "http://foo.bar" + emptyPathHTTPUpstream := Upstream{ + ID: "/", + Path: "/", + URI: emptyPathHTTP, + InsecureSkipTLSVerify: skipVerify, + PassHostHeader: passHostHeader, + ProxyWebSockets: proxyWebSockets, + FlushInterval: &flushInterval, + } + + validFileWithFragment := "file://var/lib/website#/bar" + validFileWithFragmentUpstream := Upstream{ + ID: "/bar", + Path: "/bar", + URI: validFileWithFragment, + InsecureSkipTLSVerify: skipVerify, + PassHostHeader: passHostHeader, + ProxyWebSockets: proxyWebSockets, + FlushInterval: &flushInterval, + } + + validStatic := "static://204" + validStaticCode := 204 + validStaticUpstream := Upstream{ + ID: validStatic, + Path: validStatic, + URI: "", + Static: true, + StaticCode: &validStaticCode, + InsecureSkipTLSVerify: false, + PassHostHeader: true, + ProxyWebSockets: false, + FlushInterval: &defaultFlushInterval, + } + + invalidStatic := "static://abc" + invalidStaticCode := 200 + invalidStaticUpstream := Upstream{ + ID: invalidStatic, + Path: invalidStatic, + URI: "", + Static: true, + StaticCode: &invalidStaticCode, + InsecureSkipTLSVerify: false, + PassHostHeader: true, + ProxyWebSockets: false, + FlushInterval: &defaultFlushInterval, + } + + invalidHTTP := ":foo" + invalidHTTPErrMsg := "could not parse upstream \":foo\": parse \":foo\": missing protocol scheme" + + DescribeTable("convertLegacyUpstreams", + func(o *convertUpstreamsTableInput) { + upstreams, err := convertLegacyUpstreams(o.upstreamStrings, skipVerify, passHostHeader, proxyWebSockets, flushInterval) + + if o.errMsg != "" { + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal(o.errMsg)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + + Expect(upstreams).To(ConsistOf(o.expectedUpstreams)) + }, + Entry("with no upstreams", &convertUpstreamsTableInput{ + upstreamStrings: []string{}, + expectedUpstreams: Upstreams{}, + errMsg: "", + }), + Entry("with a valid HTTP upstream", &convertUpstreamsTableInput{ + upstreamStrings: []string{validHTTP}, + expectedUpstreams: Upstreams{validHTTPUpstream}, + errMsg: "", + }), + Entry("with a HTTP upstream with an empty path", &convertUpstreamsTableInput{ + upstreamStrings: []string{emptyPathHTTP}, + expectedUpstreams: Upstreams{emptyPathHTTPUpstream}, + errMsg: "", + }), + Entry("with a valid File upstream with a fragment", &convertUpstreamsTableInput{ + upstreamStrings: []string{validFileWithFragment}, + expectedUpstreams: Upstreams{validFileWithFragmentUpstream}, + errMsg: "", + }), + Entry("with a valid static upstream", &convertUpstreamsTableInput{ + upstreamStrings: []string{validStatic}, + expectedUpstreams: Upstreams{validStaticUpstream}, + errMsg: "", + }), + Entry("with an invalid static upstream, code is 200", &convertUpstreamsTableInput{ + upstreamStrings: []string{invalidStatic}, + expectedUpstreams: Upstreams{invalidStaticUpstream}, + errMsg: "", + }), + Entry("with an invalid HTTP upstream", &convertUpstreamsTableInput{ + upstreamStrings: []string{invalidHTTP}, + expectedUpstreams: Upstreams{}, + errMsg: invalidHTTPErrMsg, + }), + Entry("with an invalid HTTP upstream and other upstreams", &convertUpstreamsTableInput{ + upstreamStrings: []string{validHTTP, invalidHTTP}, + expectedUpstreams: Upstreams{}, + errMsg: invalidHTTPErrMsg, + }), + Entry("with multiple valid upstreams", &convertUpstreamsTableInput{ + upstreamStrings: []string{validHTTP, validFileWithFragment, validStatic}, + expectedUpstreams: Upstreams{validHTTPUpstream, validFileWithFragmentUpstream, validStaticUpstream}, + errMsg: "", + }), + ) + }) +}) diff --git a/pkg/apis/options/load_test.go b/pkg/apis/options/load_test.go index 0e61d707..9f6f5d4f 100644 --- a/pkg/apis/options/load_test.go +++ b/pkg/apis/options/load_test.go @@ -4,7 +4,6 @@ import ( "fmt" "io/ioutil" "os" - "testing" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" @@ -12,11 +11,6 @@ import ( "github.com/spf13/pflag" ) -func TestOptionsSuite(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Options Suite") -} - var _ = Describe("Load", func() { Context("with a testOptions structure", func() { type TestOptionSubStruct struct { @@ -300,6 +294,11 @@ var _ = Describe("Load", func() { input: &Options{}, expectedOutput: NewOptions(), }), + Entry("with an empty LegacyOptions struct, should return default values", &testOptionsTableInput{ + flagSet: NewFlagSet, + input: &LegacyOptions{}, + expectedOutput: NewLegacyOptions(), + }), ) }) }) diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 5e26ea1e..325c263c 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -24,7 +24,6 @@ type Options struct { ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"` PingPath string `flag:"ping-path" cfg:"ping_path"` PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"` - ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` HTTPAddress string `flag:"http-address" cfg:"http_address"` HTTPSAddress string `flag:"https-address" cfg:"https_address"` ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"` @@ -64,26 +63,26 @@ type Options struct { Session SessionOptions `cfg:",squash"` Logging Logging `cfg:",squash"` - Upstreams []string `flag:"upstream" cfg:"upstreams"` - SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` - SkipAuthStripHeaders bool `flag:"skip-auth-strip-headers" cfg:"skip_auth_strip_headers"` - SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` - ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` - PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` - SetBasicAuth bool `flag:"set-basic-auth" cfg:"set_basic_auth"` - PreferEmailToUser bool `flag:"prefer-email-to-user" cfg:"prefer_email_to_user"` - BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` - PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` - PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` - SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` - PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` - SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` - SSLUpstreamInsecureSkipVerify bool `flag:"ssl-upstream-insecure-skip-verify" cfg:"ssl_upstream_insecure_skip_verify"` - SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` - SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"` - PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"` - SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` - FlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval"` + // Not used in the legacy config, name not allowed to match an external key (upstreams) + // TODO(JoelSpeed): Rename when legacy config is removed + UpstreamServers Upstreams `cfg:",internal"` + + SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` + SkipAuthStripHeaders bool `flag:"skip-auth-strip-headers" cfg:"skip_auth_strip_headers"` + SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` + ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` + PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` + SetBasicAuth bool `flag:"set-basic-auth" cfg:"set_basic_auth"` + PreferEmailToUser bool `flag:"prefer-email-to-user" cfg:"prefer_email_to_user"` + BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` + PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` + SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` + PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` + SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` + SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` + SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"` + PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"` + SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` // These options allow for other providers besides Google, with // potential overrides. @@ -114,7 +113,6 @@ type Options struct { // internal values that are set after config validation redirectURL *url.URL - proxyURLs []*url.URL compiledRegex []*regexp.Regexp provider providers.Provider signatureData *SignatureData @@ -125,7 +123,6 @@ type Options struct { // Options for Getting internal values func (o *Options) GetRedirectURL() *url.URL { return o.redirectURL } -func (o *Options) GetProxyURLs() []*url.URL { return o.proxyURLs } func (o *Options) GetCompiledRegex() []*regexp.Regexp { return o.compiledRegex } func (o *Options) GetProvider() providers.Provider { return o.provider } func (o *Options) GetSignatureData() *SignatureData { return o.signatureData } @@ -135,7 +132,6 @@ func (o *Options) GetRealClientIPParser() ipapi.RealClientIPParser { return o.re // Options for Setting internal values func (o *Options) SetRedirectURL(s *url.URL) { o.redirectURL = s } -func (o *Options) SetProxyURLs(s []*url.URL) { o.proxyURLs = s } func (o *Options) SetCompiledRegex(s []*regexp.Regexp) { o.compiledRegex = s } func (o *Options) SetProvider(s providers.Provider) { o.provider = s } func (o *Options) SetSignatureData(s *SignatureData) { o.signatureData = s } @@ -149,7 +145,6 @@ func NewOptions() *Options { ProxyPrefix: "/oauth2", ProviderType: "google", PingPath: "/ping", - ProxyWebSockets: true, HTTPAddress: "127.0.0.1:4180", HTTPSAddress: ":443", RealClientIPHeader: "X-Real-IP", @@ -160,13 +155,10 @@ func NewOptions() *Options { AzureTenant: "common", SetXAuthRequest: false, SkipAuthPreflight: false, - SkipAuthStripHeaders: false, - FlushInterval: time.Duration(1) * time.Second, PassBasicAuth: true, SetBasicAuth: false, PassUserHeaders: true, PassAccessToken: false, - PassHostHeader: true, SetAuthorization: false, PassAuthorization: false, PreferEmailToUser: false, diff --git a/pkg/apis/options/options_suite_test.go b/pkg/apis/options/options_suite_test.go new file mode 100644 index 00000000..a25cbe42 --- /dev/null +++ b/pkg/apis/options/options_suite_test.go @@ -0,0 +1,16 @@ +package options + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestOptionsSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Options Suite") +} diff --git a/pkg/validation/options.go b/pkg/validation/options.go index d0f4ba06..21601d5a 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -176,17 +176,7 @@ func Validate(o *options.Options) error { redirectURL, msgs = parseURL(o.RawRedirectURL, "redirect", msgs) o.SetRedirectURL(redirectURL) - for _, u := range o.Upstreams { - upstreamURL, err := url.Parse(u) - if err != nil { - msgs = append(msgs, fmt.Sprintf("error parsing upstream: %s", err)) - } else { - if upstreamURL.Path == "" { - upstreamURL.Path = "/" - } - o.SetProxyURLs(append(o.GetProxyURLs(), upstreamURL)) - } - } + msgs = append(msgs, validateUpstreams(o.UpstreamServers)...) for _, u := range o.SkipAuthRegex { compiledRegex, err := regexp.Compile(u) diff --git a/pkg/validation/options_test.go b/pkg/validation/options_test.go index 15761a28..8c9a892f 100644 --- a/pkg/validation/options_test.go +++ b/pkg/validation/options_test.go @@ -22,7 +22,11 @@ const ( func testOptions() *options.Options { o := options.NewOptions() - o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8080/") + o.UpstreamServers = append(o.UpstreamServers, options.Upstream{ + ID: "upstream", + Path: "/", + URI: "http://127.0.0.1:8080/", + }) o.Cookie.Secret = cookieSecret o.ClientID = clientID o.ClientSecret = clientSecret @@ -140,26 +144,6 @@ func TestRedirectURL(t *testing.T) { assert.Equal(t, expected, o.GetRedirectURL()) } -func TestProxyURLs(t *testing.T) { - o := testOptions() - o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") - assert.Equal(t, nil, Validate(o)) - expected := []*url.URL{ - {Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, - // note the '/' was added - {Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, - } - assert.Equal(t, expected, o.GetProxyURLs()) -} - -func TestProxyURLsError(t *testing.T) { - o := testOptions() - o.Upstreams = append(o.Upstreams, "127.0.0.1:8081") - err := Validate(o) - assert.NotEqual(t, nil, err) - assert.Contains(t, err.Error(), "error parsing upstream") -} - func TestCompiledRegex(t *testing.T) { o := testOptions() regexps := []string{"/foo/.*", "/ba[rz]/quux"} From 5dbcd737229e0b12a2613e6455d609e2428d6ab6 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Tue, 26 May 2020 20:06:27 +0100 Subject: [PATCH 2/5] Configure OAuth2 Proxy to use new upstreams package and LegacyConfig --- main.go | 10 ++- oauthproxy.go | 185 ++-------------------------------------- oauthproxy_test.go | 204 ++++++++++++--------------------------------- 3 files changed, 70 insertions(+), 329 deletions(-) diff --git a/main.go b/main.go index b42c2c05..24d48072 100644 --- a/main.go +++ b/main.go @@ -32,13 +32,19 @@ func main() { return } - opts := options.NewOptions() - err := options.Load(*config, flagSet, opts) + legacyOpts := options.NewLegacyOptions() + err := options.Load(*config, flagSet, legacyOpts) if err != nil { logger.Printf("ERROR: Failed to load config: %v", err) os.Exit(1) } + opts, err := legacyOpts.ToOptions() + if err != nil { + logger.Printf("ERROR: Failed to convert config: %v", err) + os.Exit(1) + } + err = validation.Validate(opts) if err != nil { logger.Printf("%s", err) diff --git a/oauthproxy.go b/oauthproxy.go index 4b310b9b..034fe6a3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/tls" b64 "encoding/base64" "encoding/json" "errors" @@ -10,15 +9,12 @@ import ( "html/template" "net" "net/http" - "net/http/httputil" "net/url" "regexp" - "strconv" "strings" "time" "github.com/coreos/go-oidc" - "github.com/mbland/hmacauth" ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -28,37 +24,17 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/providers" - "github.com/yhat/wsutil" ) const ( - // SignatureHeader is the name of the request header containing the GAP Signature - // Part of hmacauth - SignatureHeader = "GAP-Signature" - httpScheme = "http" httpsScheme = "https" applicationJSON = "application/json" ) -// SignatureHeaders contains the headers to be signed by the hmac algorithm -// Part of hmacauth -var SignatureHeaders = []string{ - "Content-Length", - "Content-Md5", - "Content-Type", - "Date", - "Authorization", - "X-Forwarded-User", - "X-Forwarded-Email", - "X-Forwarded-Preferred-User", - "X-Forwarded-Access-Token", - "Cookie", - "Gap-Auth", -} - var ( // ErrNeedsLogin means the user should be redirected to the login page ErrNeedsLogin = errors.New("redirect to login page") @@ -124,116 +100,6 @@ type OAuthProxy struct { Footer string } -// UpstreamProxy represents an upstream server to proxy to -type UpstreamProxy struct { - upstream string - handler http.Handler - wsHandler http.Handler - auth hmacauth.HmacAuth -} - -// ServeHTTP proxies requests to the upstream provider while signing the -// request headers -func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Header().Set("GAP-Upstream-Address", u.upstream) - if u.auth != nil { - r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) - u.auth.SignRequest(r) - } - if u.wsHandler != nil && strings.EqualFold(r.Header.Get("Connection"), "upgrade") && r.Header.Get("Upgrade") == "websocket" { - u.wsHandler.ServeHTTP(w, r) - } else { - u.handler.ServeHTTP(w, r) - } - -} - -// NewReverseProxy creates a new reverse proxy for proxying requests to upstream -// servers -func NewReverseProxy(target *url.URL, opts *options.Options) (proxy *httputil.ReverseProxy) { - proxy = httputil.NewSingleHostReverseProxy(target) - proxy.FlushInterval = opts.FlushInterval - if opts.SSLUpstreamInsecureSkipVerify { - proxy.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - } - setProxyErrorHandler(proxy, opts) - return proxy -} - -func setProxyErrorHandler(proxy *httputil.ReverseProxy, opts *options.Options) { - templates := loadTemplates(opts.CustomTemplatesDir) - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, proxyErr error) { - logger.Printf("Error proxying to upstream server: %v", proxyErr) - w.WriteHeader(http.StatusBadGateway) - data := struct { - Title string - Message string - ProxyPrefix string - }{ - Title: "Bad Gateway", - Message: "Error proxying to upstream server", - ProxyPrefix: opts.ProxyPrefix, - } - templates.ExecuteTemplate(w, "error.html", data) - } -} - -func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { - director := proxy.Director - proxy.Director = func(req *http.Request) { - director(req) - // use RequestURI so that we aren't unescaping encoded slashes in the request path - req.Host = target.Host - req.URL.Opaque = req.RequestURI - req.URL.RawQuery = "" - } -} - -func setProxyDirector(proxy *httputil.ReverseProxy) { - director := proxy.Director - proxy.Director = func(req *http.Request) { - director(req) - // use RequestURI so that we aren't unescaping encoded slashes in the request path - req.URL.Opaque = req.RequestURI - req.URL.RawQuery = "" - } -} - -// NewFileServer creates a http.Handler to serve files from the filesystem -func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { - return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) -} - -// NewWebSocketOrRestReverseProxy creates a reverse proxy for REST or websocket based on url -func NewWebSocketOrRestReverseProxy(u *url.URL, opts *options.Options, auth hmacauth.HmacAuth) http.Handler { - u.Path = "" - proxy := NewReverseProxy(u, opts) - if !opts.PassHostHeader { - setProxyUpstreamHostHeader(proxy, u) - } else { - setProxyDirector(proxy) - } - - // this should give us a wss:// scheme if the url is https:// based. - var wsProxy *wsutil.ReverseProxy - if opts.ProxyWebSockets { - wsScheme := "ws" + strings.TrimPrefix(u.Scheme, "http") - wsURL := &url.URL{Scheme: wsScheme, Host: u.Host} - wsProxy = wsutil.NewSingleHostReverseProxy(wsURL) - if opts.SSLUpstreamInsecureSkipVerify { - wsProxy.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - } - return &UpstreamProxy{ - upstream: u.Host, - handler: proxy, - wsHandler: wsProxy, - auth: auth, - } -} - // NewOAuthProxy creates a new instance of OAuthProxy from the options provided func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) @@ -241,48 +107,13 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising session store: %v", err) } - serveMux := http.NewServeMux() - var auth hmacauth.HmacAuth - if sigData := opts.GetSignatureData(); sigData != nil { - auth = hmacauth.NewHmacAuth(sigData.Hash, []byte(sigData.Key), - SignatureHeader, SignatureHeaders) + templates := loadTemplates(opts.CustomTemplatesDir) + proxyErrorHandler := upstream.NewProxyErrorHandler(templates.Lookup("error.html"), opts.ProxyPrefix) + upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), proxyErrorHandler) + if err != nil { + return nil, fmt.Errorf("error initialising upstream proxy: %v", err) } - for _, u := range opts.GetProxyURLs() { - path := u.Path - host := u.Host - switch u.Scheme { - case httpScheme, httpsScheme: - logger.Printf("mapping path %q => upstream %q", path, u) - proxy := NewWebSocketOrRestReverseProxy(u, opts, auth) - serveMux.Handle(path, proxy) - case "static": - responseCode, err := strconv.Atoi(host) - if err != nil { - logger.Printf("unable to convert %q to int, use default \"200\"", host) - responseCode = 200 - } - serveMux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) { - rw.WriteHeader(responseCode) - fmt.Fprintf(rw, "Authenticated") - }) - case "file": - if u.Fragment != "" { - path = u.Fragment - } - logger.Printf("mapping path %q => file system %q", path, u.Path) - proxy := NewFileServer(path, u.Path) - uProxy := UpstreamProxy{ - upstream: path, - handler: proxy, - wsHandler: nil, - auth: nil, - } - serveMux.Handle(path, &uProxy) - default: - panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) - } - } for _, u := range opts.GetCompiledRegex() { logger.Printf("compiled skip-auth-regex => %q", u) } @@ -350,7 +181,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, sessionStore: sessionStore, - serveMux: serveMux, + serveMux: upstreamProxy, redirectURL: redirectURL, whitelistDomains: opts.WhitelistDomains, skipAuthRegex: opts.SkipAuthRegex, @@ -371,7 +202,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr PassAuthorization: opts.PassAuthorization, PreferEmailToUser: opts.PreferEmailToUser, SkipProviderButton: opts.SkipProviderButton, - templates: loadTemplates(opts.CustomTemplatesDir), + templates: templates, trustedIPs: trustedIPs, Banner: opts.Banner, Footer: opts.Footer, diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 40b2d9d1..00c33ff6 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "io/ioutil" - "net" "net/http" "net/http/httptest" "net/url" @@ -24,11 +23,11 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie" + "github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" "github.com/oauth2-proxy/oauth2-proxy/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/net/websocket" ) const ( @@ -44,143 +43,6 @@ func init() { logger.SetFlags(logger.Lshortfile) } -type WebSocketOrRestHandler struct { - restHandler http.Handler - wsHandler http.Handler -} - -func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") == "websocket" { - h.wsHandler.ServeHTTP(w, r) - } else { - h.restHandler.ServeHTTP(w, r) - } -} - -func TestWebSocketProxy(t *testing.T) { - handler := WebSocketOrRestHandler{ - restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - hostname, _, _ := net.SplitHostPort(r.Host) - _, err := w.Write([]byte(hostname)) - if err != nil { - t.Fatal(err) - } - }), - wsHandler: websocket.Handler(func(ws *websocket.Conn) { - defer func(t *testing.T) { - if err := ws.Close(); err != nil { - t.Fatal(err) - } - }(t) - var data []byte - err := websocket.Message.Receive(ws, &data) - if err != nil { - t.Fatal(err) - } - err = websocket.Message.Send(ws, data) - if err != nil { - t.Fatal(err) - } - }), - } - backend := httptest.NewServer(&handler) - t.Cleanup(backend.Close) - - backendURL, _ := url.Parse(backend.URL) - - opts := baseTestOptions() - var auth hmacauth.HmacAuth - opts.PassHostHeader = true - proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth) - frontend := httptest.NewServer(proxyHandler) - t.Cleanup(frontend.Close) - - frontendURL, _ := url.Parse(frontend.URL) - frontendWSURL := "ws://" + frontendURL.Host + "/" - - ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") - if err != nil { - t.Fatal(err) - } - request := []byte("hello, world!") - err = websocket.Message.Send(ws, request) - if err != nil { - t.Fatal(err) - } - var response = make([]byte, 1024) - err = websocket.Message.Receive(ws, &response) - if err != nil { - t.Fatal(err) - } - if g, e := string(request), string(response); g != e { - t.Errorf("got body %q; expected %q", g, e) - } - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - backendHostname, _, _ := net.SplitHostPort(backendURL.Host) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - -func TestNewReverseProxy(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - hostname, _, _ := net.SplitHostPort(r.Host) - _, err := w.Write([]byte(hostname)) - if err != nil { - t.Fatal(err) - } - })) - t.Cleanup(backend.Close) - - backendURL, _ := url.Parse(backend.URL) - backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) - backendHost := net.JoinHostPort(backendHostname, backendPort) - proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") - - proxyHandler := NewReverseProxy(proxyURL, &options.Options{FlushInterval: time.Second}) - setProxyUpstreamHostHeader(proxyHandler, proxyURL) - frontend := httptest.NewServer(proxyHandler) - t.Cleanup(frontend.Close) - - getReq, _ := http.NewRequest("GET", frontend.URL, nil) - res, _ := http.DefaultClient.Do(getReq) - bodyBytes, _ := ioutil.ReadAll(res.Body) - if g, e := string(bodyBytes), backendHostname; g != e { - t.Errorf("got body %q; expected %q", g, e) - } -} - -func TestEncodedSlashes(t *testing.T) { - var seen string - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - seen = r.RequestURI - })) - t.Cleanup(backend.Close) - - b, _ := url.Parse(backend.URL) - proxyHandler := NewReverseProxy(b, &options.Options{FlushInterval: time.Second}) - setProxyDirector(proxyHandler) - frontend := httptest.NewServer(proxyHandler) - t.Cleanup(frontend.Close) - - f, _ := url.Parse(frontend.URL) - encodedPath := "/a%2Fb/?c=1" - getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} - _, err := http.DefaultClient.Do(getReq) - if err != nil { - t.Fatal(err) - } - if seen != encodedPath { - t.Errorf("got bad request %q expected %q", seen, encodedPath) - } -} - func TestRobotsTxt(t *testing.T) { opts := baseTestOptions() err := validation.Validate(opts) @@ -562,7 +424,14 @@ func TestBasicAuthPassword(t *testing.T) { } })) opts := baseTestOptions() - opts.Upstreams = append(opts.Upstreams, providerServer.URL) + opts.UpstreamServers = options.Upstreams{ + { + ID: providerServer.URL, + Path: "/", + URI: providerServer.URL, + }, + } + opts.Cookie.Secure = false opts.PassBasicAuth = true opts.SetBasicAuth = true @@ -867,7 +736,7 @@ type PassAccessTokenTest struct { type PassAccessTokenTestOptions struct { PassAccessToken bool - ProxyUpstream string + ProxyUpstream options.Upstream } func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTest, error) { @@ -893,10 +762,17 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe })) patt.opts = baseTestOptions() - patt.opts.Upstreams = append(patt.opts.Upstreams, patt.providerServer.URL) - if opts.ProxyUpstream != "" { - patt.opts.Upstreams = append(patt.opts.Upstreams, opts.ProxyUpstream) + patt.opts.UpstreamServers = options.Upstreams{ + { + ID: patt.providerServer.URL, + Path: "/", + URI: patt.providerServer.URL, + }, } + if opts.ProxyUpstream.ID != "" { + patt.opts.UpstreamServers = append(patt.opts.UpstreamServers, opts.ProxyUpstream) + } + patt.opts.Cookie.Secure = false patt.opts.PassAccessToken = opts.PassAccessToken err := validation.Validate(patt.opts) @@ -999,7 +875,11 @@ func TestForwardAccessTokenUpstream(t *testing.T) { func TestStaticProxyUpstream(t *testing.T) { patTest, err := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, - ProxyUpstream: "static://200/static-proxy", + ProxyUpstream: options.Upstream{ + ID: "static-proxy", + Path: "/static-proxy", + Static: true, + }, }) if err != nil { t.Fatal(err) @@ -1572,7 +1452,13 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { t.Cleanup(upstream.Close) opts := baseTestOptions() - opts.Upstreams = append(opts.Upstreams, upstream.URL) + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } opts.SkipAuthPreflight = true err := validation.Validate(opts) assert.NoError(t, err) @@ -1641,7 +1527,13 @@ func NewSignatureTest() (*SignatureTest, error) { if err != nil { return nil, err } - opts.Upstreams = append(opts.Upstreams, upstream.URL) + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } providerHandler := func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte(`{"access_token": "my_auth_token"}`)) @@ -1716,7 +1608,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er } // This is used by the upstream to validate the signature. st.authenticator.auth = hmacauth.NewHmacAuth( - crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) + crypto.SHA1, []byte(key), upstream.SignatureHeader, upstream.SignatureHeaders) proxy.ServeHTTP(st.rw, req) return nil @@ -2110,7 +2002,13 @@ func Test_noCacheHeaders(t *testing.T) { t.Cleanup(upstream.Close) opts := baseTestOptions() - opts.Upstreams = []string{upstream.URL} + opts.UpstreamServers = options.Upstreams{ + { + ID: upstream.URL, + Path: "/", + URI: upstream.URL, + }, + } opts.SkipAuthRegex = []string{".*"} err := validation.Validate(opts) assert.NoError(t, err) @@ -2335,7 +2233,13 @@ func TestTrustedIPs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := baseTestOptions() - opts.Upstreams = []string{"static://200"} + opts.UpstreamServers = options.Upstreams{ + { + ID: "static", + Path: "/", + Static: true, + }, + } opts.TrustedIPs = tt.trustedIPs opts.ReverseProxy = tt.reverseProxy opts.RealClientIPHeader = tt.realClientIPHeader From 71dc70222b705acf687bec97c1469cc363d3329f Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 12 Jul 2020 16:47:25 +0100 Subject: [PATCH 3/5] Break legacy upstream options into LegacyUpstreams struct --- pkg/apis/options/legacy_options.go | 49 +++++++++++++++++-------- pkg/apis/options/legacy_options_test.go | 20 +++++++--- pkg/apis/options/options.go | 7 +--- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index b0198c9c..6d7c7185 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -7,31 +7,30 @@ import ( "time" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/spf13/pflag" ) type LegacyOptions struct { // Legacy options related to upstream servers - LegacyFlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval"` - LegacyPassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` - LegacyProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` - LegacySSLUpstreamInsecureSkipVerify bool `flag:"ssl-upstream-insecure-skip-verify" cfg:"ssl_upstream_insecure_skip_verify"` - LegacyUpstreams []string `flag:"upstream" cfg:"upstreams"` + LegacyUpstreams LegacyUpstreams `cfg:",squash"` Options Options `cfg:",squash"` } func NewLegacyOptions() *LegacyOptions { return &LegacyOptions{ - LegacyPassHostHeader: true, - LegacyProxyWebSockets: true, - LegacyFlushInterval: time.Duration(1) * time.Second, + LegacyUpstreams: LegacyUpstreams{ + PassHostHeader: true, + ProxyWebSockets: true, + FlushInterval: time.Duration(1) * time.Second, + }, Options: *NewOptions(), } } func (l *LegacyOptions) ToOptions() (*Options, error) { - upstreams, err := convertLegacyUpstreams(l.LegacyUpstreams, l.LegacySSLUpstreamInsecureSkipVerify, l.LegacyPassHostHeader, l.LegacyProxyWebSockets, l.LegacyFlushInterval) + upstreams, err := l.LegacyUpstreams.convert() if err != nil { return nil, fmt.Errorf("error converting upstreams: %v", err) } @@ -40,10 +39,30 @@ func (l *LegacyOptions) ToOptions() (*Options, error) { return &l.Options, nil } -func convertLegacyUpstreams(upstreamStrings []string, skipVerify, passHostHeader, proxyWebSockets bool, flushInterval time.Duration) (Upstreams, error) { +type LegacyUpstreams struct { + FlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval"` + PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` + ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` + SSLUpstreamInsecureSkipVerify bool `flag:"ssl-upstream-insecure-skip-verify" cfg:"ssl_upstream_insecure_skip_verify"` + Upstreams []string `flag:"upstream" cfg:"upstreams"` +} + +func legacyUpstreamsFlagSet() *pflag.FlagSet { + flagSet := pflag.NewFlagSet("upstreams", pflag.ExitOnError) + + flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") + flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") + flagSet.Bool("proxy-websockets", true, "enables WebSocket proxying") + flagSet.Bool("ssl-upstream-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS upstreams") + flagSet.StringSlice("upstream", []string{}, "the http url(s) of the upstream endpoint, file:// paths for static files or static:// for static response. Routing is based on the path") + + return flagSet +} + +func (l *LegacyUpstreams) convert() (Upstreams, error) { upstreams := Upstreams{} - for _, upstreamString := range upstreamStrings { + for _, upstreamString := range l.Upstreams { u, err := url.Parse(upstreamString) if err != nil { return nil, fmt.Errorf("could not parse upstream %q: %v", upstreamString, err) @@ -57,10 +76,10 @@ func convertLegacyUpstreams(upstreamStrings []string, skipVerify, passHostHeader ID: u.Path, Path: u.Path, URI: upstreamString, - InsecureSkipTLSVerify: skipVerify, - PassHostHeader: passHostHeader, - ProxyWebSockets: proxyWebSockets, - FlushInterval: &flushInterval, + InsecureSkipTLSVerify: l.SSLUpstreamInsecureSkipVerify, + PassHostHeader: l.PassHostHeader, + ProxyWebSockets: l.ProxyWebSockets, + FlushInterval: &l.FlushInterval, } switch u.Scheme { diff --git a/pkg/apis/options/legacy_options_test.go b/pkg/apis/options/legacy_options_test.go index 0e221749..3046e4a1 100644 --- a/pkg/apis/options/legacy_options_test.go +++ b/pkg/apis/options/legacy_options_test.go @@ -17,11 +17,11 @@ var _ = Describe("Legacy Options", func() { // Set upstreams and related options to test their conversion flushInterval := 5 * time.Second - legacyOpts.LegacyFlushInterval = flushInterval - legacyOpts.LegacyPassHostHeader = true - legacyOpts.LegacyProxyWebSockets = true - legacyOpts.LegacySSLUpstreamInsecureSkipVerify = true - legacyOpts.LegacyUpstreams = []string{"http://foo.bar/baz", "file://var/lib/website#/bar"} + legacyOpts.LegacyUpstreams.FlushInterval = flushInterval + legacyOpts.LegacyUpstreams.PassHostHeader = true + legacyOpts.LegacyUpstreams.ProxyWebSockets = true + legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true + legacyOpts.LegacyUpstreams.Upstreams = []string{"http://foo.bar/baz", "file://var/lib/website#/bar"} opts.UpstreamServers = Upstreams{ { @@ -133,7 +133,15 @@ var _ = Describe("Legacy Options", func() { DescribeTable("convertLegacyUpstreams", func(o *convertUpstreamsTableInput) { - upstreams, err := convertLegacyUpstreams(o.upstreamStrings, skipVerify, passHostHeader, proxyWebSockets, flushInterval) + legacyUpstreams := LegacyUpstreams{ + Upstreams: o.upstreamStrings, + SSLUpstreamInsecureSkipVerify: skipVerify, + PassHostHeader: passHostHeader, + ProxyWebSockets: proxyWebSockets, + FlushInterval: flushInterval, + } + + upstreams, err := legacyUpstreams.convert() if o.errMsg != "" { Expect(err).To(HaveOccurred()) diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 325c263c..e37dc63b 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -4,7 +4,6 @@ import ( "crypto" "net/url" "regexp" - "time" oidc "github.com/coreos/go-oidc" ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" @@ -185,14 +184,12 @@ func NewFlagSet() *pflag.FlagSet { flagSet.String("tls-key-file", "", "path to private key file") flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.Bool("set-xauthrequest", false, "set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)") - flagSet.StringSlice("upstream", []string{}, "the http url(s) of the upstream endpoint, file:// paths for static files or static:// for static response. Routing is based on the path") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.Bool("set-basic-auth", false, "set HTTP Basic Auth information in response (useful in Nginx auth_request mode)") flagSet.Bool("prefer-email-to-user", false, "Prefer to use the Email address as the Username when passing information to upstream. Will only use Username if Email is unavailable, eg. htaccess authentication. Used in conjunction with -pass-basic-auth and -pass-user-headers") flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") - flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream") flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)") flagSet.StringSlice("skip-auth-regex", []string{}, "bypass authentication for requests path's that match (may be given multiple times)") @@ -200,8 +197,6 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers") - flagSet.Bool("ssl-upstream-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS upstreams") - flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") flagSet.Bool("skip-jwt-bearer-tokens", false, "will skip requests that have verified JWT bearer tokens (default false)") flagSet.StringSlice("extra-jwt-issuers", []string{}, "if skip-jwt-bearer-tokens is set, a list of extra JWT issuer=audience pairs (where the issuer URL has a .well-known/openid-configuration or a .well-known/jwks.json)") @@ -232,7 +227,6 @@ func NewFlagSet() *pflag.FlagSet { flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. //sign_in)") flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") - flagSet.Bool("proxy-websockets", true, "enables WebSocket proxying") flagSet.String("session-store-type", "cookie", "the session storage provider to use") flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") @@ -272,6 +266,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet()) + flagSet.AddFlagSet(legacyUpstreamsFlagSet()) return flagSet } From 6b2706981213305a3136fc9b27a2556301ca7cd1 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Thu, 16 Jul 2020 22:44:00 +0100 Subject: [PATCH 4/5] Add changelog entry for integrating new upstream proxy --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad30c6ae..25f29079 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ## Changes since v6.0.0 +- [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed) - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) - [#624](https://github.com/oauth2-proxy/oauth2-proxy/pull/624) Allow stripping authentication headers from whitelisted requests with `--skip-auth-strip-headers` (@NickMeves) - [#673](https://github.com/oauth2-proxy/oauth2-proxy/pull/673) Add --session-cookie-minimal option to create session cookies with no tokens (@NickMeves) From d43b372ca9bb72a35fcdee72326de9374f9cb60c Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 19 Jul 2020 14:00:52 +0100 Subject: [PATCH 5/5] Use bool pointers for upstream options that default to true --- pkg/apis/options/legacy_options.go | 8 +++---- pkg/apis/options/legacy_options_test.go | 29 +++++++++++++------------ pkg/apis/options/upstreams.go | 4 ++-- pkg/upstream/http.go | 4 ++-- pkg/upstream/http_test.go | 17 +++++++++------ pkg/validation/upstreams.go | 4 ++-- pkg/validation/upstreams_test.go | 13 +++++------ 7 files changed, 41 insertions(+), 38 deletions(-) diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index 6d7c7185..762eba07 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -77,8 +77,8 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) { Path: u.Path, URI: upstreamString, InsecureSkipTLSVerify: l.SSLUpstreamInsecureSkipVerify, - PassHostHeader: l.PassHostHeader, - ProxyWebSockets: l.ProxyWebSockets, + PassHostHeader: &l.PassHostHeader, + ProxyWebSockets: &l.ProxyWebSockets, FlushInterval: &l.FlushInterval, } @@ -104,8 +104,8 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) { // Force defaults compatible with static responses upstream.URI = "" upstream.InsecureSkipTLSVerify = false - upstream.PassHostHeader = true - upstream.ProxyWebSockets = false + upstream.PassHostHeader = nil + upstream.ProxyWebSockets = nil flush := 1 * time.Second upstream.FlushInterval = &flush } diff --git a/pkg/apis/options/legacy_options_test.go b/pkg/apis/options/legacy_options_test.go index 3046e4a1..4223a293 100644 --- a/pkg/apis/options/legacy_options_test.go +++ b/pkg/apis/options/legacy_options_test.go @@ -23,6 +23,7 @@ var _ = Describe("Legacy Options", func() { legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true legacyOpts.LegacyUpstreams.Upstreams = []string{"http://foo.bar/baz", "file://var/lib/website#/bar"} + truth := true opts.UpstreamServers = Upstreams{ { ID: "/baz", @@ -30,8 +31,8 @@ var _ = Describe("Legacy Options", func() { URI: "http://foo.bar/baz", FlushInterval: &flushInterval, InsecureSkipTLSVerify: true, - PassHostHeader: true, - ProxyWebSockets: true, + PassHostHeader: &truth, + ProxyWebSockets: &truth, }, { ID: "/bar", @@ -39,8 +40,8 @@ var _ = Describe("Legacy Options", func() { URI: "file://var/lib/website#/bar", FlushInterval: &flushInterval, InsecureSkipTLSVerify: true, - PassHostHeader: true, - ProxyWebSockets: true, + PassHostHeader: &truth, + ProxyWebSockets: &truth, }, } @@ -72,8 +73,8 @@ var _ = Describe("Legacy Options", func() { Path: "/baz", URI: validHTTP, InsecureSkipTLSVerify: skipVerify, - PassHostHeader: passHostHeader, - ProxyWebSockets: proxyWebSockets, + PassHostHeader: &passHostHeader, + ProxyWebSockets: &proxyWebSockets, FlushInterval: &flushInterval, } @@ -84,8 +85,8 @@ var _ = Describe("Legacy Options", func() { Path: "/", URI: emptyPathHTTP, InsecureSkipTLSVerify: skipVerify, - PassHostHeader: passHostHeader, - ProxyWebSockets: proxyWebSockets, + PassHostHeader: &passHostHeader, + ProxyWebSockets: &proxyWebSockets, FlushInterval: &flushInterval, } @@ -95,8 +96,8 @@ var _ = Describe("Legacy Options", func() { Path: "/bar", URI: validFileWithFragment, InsecureSkipTLSVerify: skipVerify, - PassHostHeader: passHostHeader, - ProxyWebSockets: proxyWebSockets, + PassHostHeader: &passHostHeader, + ProxyWebSockets: &proxyWebSockets, FlushInterval: &flushInterval, } @@ -109,8 +110,8 @@ var _ = Describe("Legacy Options", func() { Static: true, StaticCode: &validStaticCode, InsecureSkipTLSVerify: false, - PassHostHeader: true, - ProxyWebSockets: false, + PassHostHeader: nil, + ProxyWebSockets: nil, FlushInterval: &defaultFlushInterval, } @@ -123,8 +124,8 @@ var _ = Describe("Legacy Options", func() { Static: true, StaticCode: &invalidStaticCode, InsecureSkipTLSVerify: false, - PassHostHeader: true, - ProxyWebSockets: false, + PassHostHeader: nil, + ProxyWebSockets: nil, FlushInterval: &defaultFlushInterval, } diff --git a/pkg/apis/options/upstreams.go b/pkg/apis/options/upstreams.go index 5a8eebe5..e879a107 100644 --- a/pkg/apis/options/upstreams.go +++ b/pkg/apis/options/upstreams.go @@ -52,9 +52,9 @@ type Upstream struct { // PassHostHeader determines whether the request host header should be proxied // to the upstream server. // Defaults to true. - PassHostHeader bool `json:"passHostHeader"` + PassHostHeader *bool `json:"passHostHeader"` // ProxyWebSockets enables proxying of websockets to upstream servers // Defaults to true. - ProxyWebSockets bool `json:"proxyWebSockets"` + ProxyWebSockets *bool `json:"proxyWebSockets"` } diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index fa7b2a8a..5011b775 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -49,7 +49,7 @@ func newHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *option // Set up a WebSocket proxy if required var wsProxy http.Handler - if upstream.ProxyWebSockets { + if upstream.ProxyWebSockets == nil || *upstream.ProxyWebSockets { wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) } @@ -110,7 +110,7 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr } // Set the request director based on the PassHostHeader option - if !upstream.PassHostHeader { + if upstream.PassHostHeader != nil && !*upstream.PassHostHeader { setProxyUpstreamHostHeader(proxy, target) } else { setProxyDirector(proxy) diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go index 8c601880..c49c0c35 100644 --- a/pkg/upstream/http_test.go +++ b/pkg/upstream/http_test.go @@ -24,6 +24,8 @@ var _ = Describe("HTTP Upstream Suite", func() { const flushInterval5s = 5 * time.Second const flushInterval1s = 1 * time.Second + truth := true + falsum := false type httpUpstreamTableInput struct { id string @@ -51,10 +53,11 @@ var _ = Describe("HTTP Upstream Suite", func() { rw := httptest.NewRecorder() flush := 1 * time.Second + upstream := options.Upstream{ ID: in.id, - PassHostHeader: true, - ProxyWebSockets: false, + PassHostHeader: &truth, + ProxyWebSockets: &falsum, InsecureSkipTLSVerify: false, FlushInterval: &flush, } @@ -258,8 +261,8 @@ var _ = Describe("HTTP Upstream Suite", func() { flush := 1 * time.Second upstream := options.Upstream{ ID: "noPassHost", - PassHostHeader: false, - ProxyWebSockets: false, + PassHostHeader: &falsum, + ProxyWebSockets: &falsum, InsecureSkipTLSVerify: false, FlushInterval: &flush, } @@ -302,7 +305,7 @@ var _ = Describe("HTTP Upstream Suite", func() { ID: "foo123", FlushInterval: &in.flushInterval, InsecureSkipTLSVerify: in.skipVerify, - ProxyWebSockets: in.proxyWebSockets, + ProxyWebSockets: &in.proxyWebSockets, } handler := newHTTPUpstreamProxy(upstream, u, in.sigData, in.errorHandler) @@ -370,8 +373,8 @@ var _ = Describe("HTTP Upstream Suite", func() { flush := 1 * time.Second upstream := options.Upstream{ ID: "websocketProxy", - PassHostHeader: true, - ProxyWebSockets: true, + PassHostHeader: &truth, + ProxyWebSockets: &truth, InsecureSkipTLSVerify: false, FlushInterval: &flush, } diff --git a/pkg/validation/upstreams.go b/pkg/validation/upstreams.go index be9c0ea7..2b491a5d 100644 --- a/pkg/validation/upstreams.go +++ b/pkg/validation/upstreams.go @@ -73,10 +73,10 @@ func validateStaticUpstream(upstream options.Upstream) []string { if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second { msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) } - if !upstream.PassHostHeader { + if upstream.PassHostHeader != nil { msgs = append(msgs, fmt.Sprintf("upstream %q has passHostHeader, but is a static upstream, this will have no effect.", upstream.ID)) } - if !upstream.ProxyWebSockets { + if upstream.ProxyWebSockets != nil { msgs = append(msgs, fmt.Sprintf("upstream %q has proxyWebSockets, but is a static upstream, this will have no effect.", upstream.ID)) } diff --git a/pkg/validation/upstreams_test.go b/pkg/validation/upstreams_test.go index 4995bf29..86f3da66 100644 --- a/pkg/validation/upstreams_test.go +++ b/pkg/validation/upstreams_test.go @@ -17,6 +17,7 @@ var _ = Describe("Upstreams", func() { flushInterval := 5 * time.Second staticCode200 := 200 + truth := true validHTTPUpstream := options.Upstream{ ID: "validHTTPUpstream", @@ -24,11 +25,9 @@ var _ = Describe("Upstreams", func() { URI: "http://localhost:8080", } validStaticUpstream := options.Upstream{ - ID: "validStaticUpstream", - Path: "/validStaticUpstream", - Static: true, - PassHostHeader: true, // This would normally be defaulted - ProxyWebSockets: true, // this would normally be defaulted + ID: "validStaticUpstream", + Path: "/validStaticUpstream", + Static: true, } validFileUpstream := options.Upstream{ ID: "validFileUpstream", @@ -134,8 +133,8 @@ var _ = Describe("Upstreams", func() { URI: "ftp://foo", Static: true, FlushInterval: &flushInterval, - PassHostHeader: false, - ProxyWebSockets: false, + PassHostHeader: &truth, + ProxyWebSockets: &truth, InsecureSkipTLSVerify: true, }, },