diff --git a/CHANGELOG.md b/CHANGELOG.md index 507a577a..e1a61d59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ ## Changes since v6.0.0 +- [#577](https://github.com/oauth2-proxy/oauth2-proxy/pull/577) Move Cipher and Session Store initialisation out of Validation (@JoelSpeed) + # v6.0.0 ## Release Highlights diff --git a/go.sum b/go.sum index 7e92e512..b9955838 100644 --- a/go.sum +++ b/go.sum @@ -202,6 +202,7 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -225,6 +226,7 @@ golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= diff --git a/main.go b/main.go index 9b04e24d..0720ba86 100644 --- a/main.go +++ b/main.go @@ -45,7 +45,11 @@ func main() { } validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) - oauthproxy := NewOAuthProxy(opts, validator) + oauthproxy, err := NewOAuthProxy(opts, validator) + if err != nil { + logger.Printf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) + os.Exit(1) + } if len(opts.Banner) >= 1 { if opts.Banner == "-" { diff --git a/oauthproxy.go b/oauthproxy.go index d5e71844..f09c97d8 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -26,6 +26,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" + "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/yhat/wsutil" ) @@ -231,7 +232,12 @@ func NewWebSocketOrRestReverseProxy(u *url.URL, opts *options.Options, auth hmac } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided -func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthProxy { +func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { + sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) + if err != nil { + return nil, fmt.Errorf("error initialising session store: %v", err) + } + serveMux := http.NewServeMux() var auth hmacauth.HmacAuth if sigData := opts.GetSignatureData(); sigData != nil { @@ -321,7 +327,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthPro ProxyPrefix: opts.ProxyPrefix, provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, - sessionStore: opts.GetSessionStore(), + sessionStore: sessionStore, serveMux: serveMux, redirectURL: redirectURL, whitelistDomains: opts.WhitelistDomains, @@ -345,7 +351,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) *OAuthPro templates: loadTemplates(opts.CustomTemplatesDir), Banner: opts.Banner, Footer: opts.Footer, - } + }, nil } // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 27b74795..cd7e8e53 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -35,7 +35,7 @@ const ( // The rawCookieSecret is 32 bytes and the base64CookieSecret is the base64 // encoded version of this. rawCookieSecret = "secretthirtytwobytes+abcdefghijk" - base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpamsK" + base64CookieSecret = "c2VjcmV0dGhpcnR5dHdvYnl0ZXMrYWJjZGVmZ2hpams" ) func init() { @@ -82,10 +82,10 @@ func TestWebSocketProxy(t *testing.T) { backendURL, _ := url.Parse(backend.URL) - options := options.NewOptions() + opts := baseTestOptions() var auth hmacauth.HmacAuth - options.PassHostHeader = true - proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, options, auth) + opts.PassHostHeader = true + proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, opts, auth) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -172,13 +172,14 @@ func TestEncodedSlashes(t *testing.T) { } func TestRobotsTxt(t *testing.T) { - opts := options.NewOptions() + opts := baseTestOptions() opts.ClientID = "asdlkjx" opts.ClientSecret = "alkgks" opts.Cookie.Secret = rawCookieSecret validation.Validate(opts) - proxy := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + assert.NoError(t, err) rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/robots.txt", nil) proxy.ServeHTTP(rw, req) @@ -187,7 +188,7 @@ func TestRobotsTxt(t *testing.T) { } func TestIsValidRedirect(t *testing.T) { - opts := options.NewOptions() + opts := baseTestOptions() opts.ClientID = "skdlfj" opts.ClientSecret = "fgkdsgj" opts.Cookie.Secret = base64CookieSecret @@ -202,7 +203,8 @@ func TestIsValidRedirect(t *testing.T) { } validation.Validate(opts) - proxy := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + assert.NoError(t, err) testCases := []struct { Desc, Redirect string @@ -453,11 +455,10 @@ func TestOpenRedirects(t *testing.T) { "www.whitelisteddomain.tld", } err := validation.Validate(opts) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - proxy := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + assert.NoError(t, err) file, err := os.Open("./test/openredirects.txt") if err != nil { @@ -545,7 +546,7 @@ func TestBasicAuthPassword(t *testing.T) { w.WriteHeader(200) w.Write([]byte(payload)) })) - opts := options.NewOptions() + opts := baseTestOptions() opts.Upstreams = append(opts.Upstreams, providerServer.URL) // The CookieSecret must be 32 bytes in order to create the AES // cipher. @@ -564,9 +565,10 @@ func TestBasicAuthPassword(t *testing.T) { const emailAddress = "john.doe@example.com" opts.SetProvider(NewTestProvider(providerURL, emailAddress)) - proxy := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) + assert.NoError(t, err) rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) @@ -611,11 +613,12 @@ func TestBasicAuthPassword(t *testing.T) { } func TestBasicAuthWithEmail(t *testing.T) { - opts := options.NewOptions() + opts := baseTestOptions() opts.PassBasicAuth = true opts.PassUserHeaders = false opts.PreferEmailToUser = false opts.BasicAuthPassword = "This is a secure password" + opts.Cookie.Secret = rawCookieSecret validation.Validate(opts) const emailAddress = "john.doe@example.com" @@ -635,9 +638,10 @@ func TestBasicAuthWithEmail(t *testing.T) { { rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil) - proxy := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) + assert.NoError(t, err) proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, expectedUserHeader, req.Header["Authorization"][0]) assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) @@ -648,9 +652,10 @@ func TestBasicAuthWithEmail(t *testing.T) { rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil) - proxy := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) + assert.NoError(t, err) proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, expectedEmailHeader, req.Header["Authorization"][0]) assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) @@ -658,10 +663,11 @@ func TestBasicAuthWithEmail(t *testing.T) { } func TestPassUserHeadersWithEmail(t *testing.T) { - opts := options.NewOptions() + opts := baseTestOptions() opts.PassBasicAuth = false opts.PassUserHeaders = true opts.PreferEmailToUser = false + opts.Cookie.Secret = base64CookieSecret validation.Validate(opts) const emailAddress = "john.doe@example.com" @@ -677,9 +683,10 @@ func TestPassUserHeadersWithEmail(t *testing.T) { { rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil) - proxy := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) + assert.NoError(t, err) proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) } @@ -689,9 +696,10 @@ func TestPassUserHeadersWithEmail(t *testing.T) { rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil) - proxy := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) + assert.NoError(t, err) proxy.addHeadersForProxying(rw, req, session) assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) } @@ -727,7 +735,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes w.Write([]byte(payload)) })) - t.opts = options.NewOptions() + t.opts = baseTestOptions() t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) if opts.ProxyUpstream != "" { t.opts.Upstreams = append(t.opts.Upstreams, opts.ProxyUpstream) @@ -745,9 +753,13 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes const emailAddress = "michael.bland@gsa.gov" t.opts.SetProvider(NewTestProvider(providerURL, emailAddress)) - t.proxy = NewOAuthProxy(t.opts, func(email string) bool { + var err error + t.proxy, err = NewOAuthProxy(t.opts, func(email string) bool { return email == emailAddress }) + if err != nil { + panic(err) + } return t } @@ -886,16 +898,20 @@ const signInSkipProvider = `>Found<` func NewSignInPageTest(skipProvider bool) *SignInPageTest { var sipTest SignInPageTest - sipTest.opts = options.NewOptions() + sipTest.opts = baseTestOptions() sipTest.opts.Cookie.Secret = rawCookieSecret sipTest.opts.ClientID = "lkdgj" sipTest.opts.ClientSecret = "sgiufgoi" sipTest.opts.SkipProviderButton = skipProvider validation.Validate(sipTest.opts) - sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool { + var err error + sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { return true }) + if err != nil { + panic(err) + } sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) @@ -987,7 +1003,7 @@ type OptionsModifier func(*options.Options) func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { var pcTest ProcessCookieTest - pcTest.opts = options.NewOptions() + pcTest.opts = baseTestOptions() for _, modifier := range modifiers { modifier(pcTest.opts) } @@ -999,9 +1015,13 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi pcTest.opts.Cookie.Refresh = time.Hour validation.Validate(pcTest.opts) - pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { + var err error + pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) + if err != nil { + panic(err) + } pcTest.proxy.provider = &TestProvider{ ValidToken: opts.providerValidateCookieResponse, } @@ -1201,13 +1221,19 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest - pcTest.opts = options.NewOptions() + pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true - validation.Validate(pcTest.opts) + err := validation.Validate(pcTest.opts) + if err != nil { + panic(err) + } - pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) + if err != nil { + panic(err) + } pcTest.proxy.provider = &TestProvider{ ValidToken: true, } @@ -1232,14 +1258,18 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest - pcTest.opts = options.NewOptions() + pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true pcTest.opts.SetBasicAuth = true validation.Validate(pcTest.opts) - pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { + var err error + pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) + if err != nil { + panic(err) + } pcTest.proxy.provider = &TestProvider{ ValidToken: true, } @@ -1266,14 +1296,18 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest - pcTest.opts = options.NewOptions() + pcTest.opts = baseTestOptions() pcTest.opts.SetXAuthRequest = true pcTest.opts.SetBasicAuth = false validation.Validate(pcTest.opts) - pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { + var err error + pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { return pcTest.validateUser }) + if err != nil { + panic(err) + } pcTest.proxy.provider = &TestProvider{ ValidToken: true, } @@ -1303,18 +1337,16 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { })) defer upstream.Close() - opts := options.NewOptions() + opts := baseTestOptions() opts.Upstreams = append(opts.Upstreams, upstream.URL) - opts.ClientID = "aljsal" - opts.ClientSecret = "jglkfsdgj" - opts.Cookie.Secret = base64CookieSecret opts.SkipAuthPreflight = true validation.Validate(opts) upstreamURL, _ := url.Parse(upstream.URL) opts.SetProvider(NewTestProvider(upstreamURL, "")) - proxy := NewOAuthProxy(opts, func(string) bool { return false }) + proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) + assert.NoError(t, err) rw := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) proxy.ServeHTTP(rw, req) @@ -1353,7 +1385,7 @@ type SignatureTest struct { } func NewSignatureTest() *SignatureTest { - opts := options.NewOptions() + opts := baseTestOptions() opts.Cookie.Secret = rawCookieSecret opts.ClientID = "client ID" opts.ClientSecret = "client secret" @@ -1409,7 +1441,10 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { if err != nil { panic(err) } - proxy := NewOAuthProxy(st.opts, func(email string) bool { return true }) + proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true }) + if err != nil { + panic(err) + } var bodyBuf io.ReadCloser if body != "" { @@ -1461,10 +1496,12 @@ func TestRequestSignaturePostRequest(t *testing.T) { } func TestGetRedirect(t *testing.T) { - options := options.NewOptions() - _ = validation.Validate(options) - require.NotEmpty(t, options.ProxyPrefix) - proxy := NewOAuthProxy(options, func(s string) bool { return false }) + opts := baseTestOptions() + err := validation.Validate(opts) + assert.NoError(t, err) + require.NotEmpty(t, opts.ProxyPrefix) + proxy, err := NewOAuthProxy(opts, func(s string) bool { return false }) + assert.NoError(t, err) tests := []struct { name string @@ -1500,14 +1537,19 @@ type ajaxRequestTest struct { func newAjaxRequestTest() *ajaxRequestTest { test := &ajaxRequestTest{} - test.opts = options.NewOptions() + test.opts = baseTestOptions() test.opts.Cookie.Secret = base64CookieSecret test.opts.ClientID = "gkljfdl" test.opts.ClientSecret = "sdflkjs" validation.Validate(test.opts) - test.proxy = NewOAuthProxy(test.opts, func(email string) bool { + + var err error + test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool { return true }) + if err != nil { + panic(err) + } return test } @@ -1558,11 +1600,12 @@ func TestAjaxForbiddendRequest(t *testing.T) { } func TestClearSplitCookie(t *testing.T) { - opts := options.NewOptions() + opts := baseTestOptions() + opts.Cookie.Secret = base64CookieSecret opts.Cookie.Name = "oauth2" opts.Cookie.Domains = []string{"abc"} store, err := cookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1587,11 +1630,11 @@ func TestClearSplitCookie(t *testing.T) { } func TestClearSingleCookie(t *testing.T) { - opts := options.NewOptions() + opts := baseTestOptions() opts.Cookie.Name = "oauth2" opts.Cookie.Domains = []string{"abc"} store, err := cookie.NewCookieSessionStore(&opts.Session, &opts.Cookie) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1768,13 +1811,14 @@ func Test_noCacheHeadersDoesNotExistsInResponseHeadersFromUpstream(t *testing.T) })) t.Cleanup(upstream.Close) - opts := options.NewOptions() + opts := baseTestOptions() opts.Upstreams = []string{upstream.URL} opts.SkipAuthRegex = []string{".*"} _ = validation.Validate(opts) - proxy := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(opts, func(email string) bool { return true }) + assert.NoError(t, err) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/upstream", nil) @@ -1788,3 +1832,12 @@ func Test_noCacheHeadersDoesNotExistsInResponseHeadersFromUpstream(t *testing.T) assert.Equal(t, "", rec.Header().Get(k)) } } + +func baseTestOptions() *options.Options { + opts := options.NewOptions() + opts.Cookie.Secret = rawCookieSecret + opts.ClientID = "cliend-id" + opts.ClientSecret = "client-secret" + opts.EmailDomains = []string{"*"} + return opts +} diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index ff9b1ca1..65bdaacc 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -8,7 +8,6 @@ import ( oidc "github.com/coreos/go-oidc" ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" - sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/providers" "github.com/spf13/pflag" ) @@ -115,7 +114,6 @@ type Options struct { proxyURLs []*url.URL compiledRegex []*regexp.Regexp provider providers.Provider - sessionStore sessionsapi.SessionStore signatureData *SignatureData oidcVerifier *oidc.IDTokenVerifier jwtBearerVerifiers []*oidc.IDTokenVerifier @@ -127,7 +125,6 @@ func (o *Options) GetRedirectURL() *url.URL { return o.re func (o *Options) GetProxyURLs() []*url.URL { return o.proxyURLs } func (o *Options) GetCompiledRegex() []*regexp.Regexp { return o.compiledRegex } func (o *Options) GetProvider() providers.Provider { return o.provider } -func (o *Options) GetSessionStore() sessionsapi.SessionStore { return o.sessionStore } func (o *Options) GetSignatureData() *SignatureData { return o.signatureData } func (o *Options) GetOIDCVerifier() *oidc.IDTokenVerifier { return o.oidcVerifier } func (o *Options) GetJWTBearerVerifiers() []*oidc.IDTokenVerifier { return o.jwtBearerVerifiers } @@ -138,7 +135,6 @@ func (o *Options) SetRedirectURL(s *url.URL) { o.redirect func (o *Options) SetProxyURLs(s []*url.URL) { o.proxyURLs = s } func (o *Options) SetCompiledRegex(s []*regexp.Regexp) { o.compiledRegex = s } func (o *Options) SetProvider(s providers.Provider) { o.provider = s } -func (o *Options) SetSessionStore(s sessionsapi.SessionStore) { o.sessionStore = s } func (o *Options) SetSignatureData(s *SignatureData) { o.signatureData = s } func (o *Options) SetOIDCVerifier(s *oidc.IDTokenVerifier) { o.oidcVerifier = s } func (o *Options) SetJWTBearerVerifiers(s []*oidc.IDTokenVerifier) { o.jwtBearerVerifiers = s } diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go index 3b2a4d19..7ff656fe 100644 --- a/pkg/apis/options/sessions.go +++ b/pkg/apis/options/sessions.go @@ -1,12 +1,9 @@ package options -import "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" - // SessionOptions contains configuration options for the SessionStore providers. type SessionOptions struct { - Type string `flag:"session-store-type" cfg:"session_store_type"` - Cipher encryption.Cipher `cfg:",internal"` - Redis RedisStoreOptions `cfg:",squash"` + Type string `flag:"session-store-type" cfg:"session_store_type"` + Redis RedisStoreOptions `cfg:",squash"` } // CookieSessionStoreType is used to indicate the CookieSessionStore should be diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 62f6b348..00fb02b4 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -127,8 +127,13 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string, // NewCookieSessionStore initialises a new instance of the SessionStore from // the configuration given func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { + cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) + if err != nil { + return nil, fmt.Errorf("error initialising cipher: %v", err) + } + return &SessionStore{ - CookieCipher: opts.Cipher, + CookieCipher: cipher, CookieOptions: cookieOpts, }, nil } diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index 2d0fd9b0..2b79aebc 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -40,6 +40,11 @@ type SessionStore struct { // NewRedisSessionStore initialises a new instance of the SessionStore from // the configuration given func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { + cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) + if err != nil { + return nil, fmt.Errorf("error initialising cipher: %v", err) + } + client, err := newRedisCmdable(opts.Redis) if err != nil { return nil, fmt.Errorf("error constructing redis client: %v", err) @@ -47,7 +52,7 @@ func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cook rs := &SessionStore{ Client: client, - CookieCipher: opts.Cipher, + CookieCipher: cipher, CookieOptions: cookieOpts, } return rs, nil diff --git a/pkg/sessions/redis/redis_store_test.go b/pkg/sessions/redis/redis_store_test.go index e7b801dd..e4bd7919 100644 --- a/pkg/sessions/redis/redis_store_test.go +++ b/pkg/sessions/redis/redis_store_test.go @@ -1,6 +1,7 @@ package redis import ( + "crypto/rand" "net/http" "net/http/httptest" "net/url" @@ -15,6 +16,10 @@ import ( ) func TestRedisStore(t *testing.T) { + secret := make([]byte, 32) + _, err := rand.Read(secret) + assert.NoError(t, err) + t.Run("save session on redis standalone", func(t *testing.T) { redisServer, err := miniredis.Run() require.NoError(t, err) @@ -25,6 +30,8 @@ func TestRedisStore(t *testing.T) { Host: redisServer.Addr(), } opts.Session.Redis.ConnectionURL = redisURL.String() + + opts.Cookie.Secret = string(secret) redisStore, err := NewRedisSessionStore(&opts.Session, &opts.Cookie) require.NoError(t, err) err = redisStore.Save( @@ -49,6 +56,8 @@ func TestRedisStore(t *testing.T) { opts.Session.Redis.SentinelConnectionURLs = []string{sentinelURL.String()} opts.Session.Redis.UseSentinel = true opts.Session.Redis.SentinelMasterName = sentinel.MasterInfo().Name + + opts.Cookie.Secret = string(secret) redisStore, err := NewRedisSessionStore(&opts.Session, &opts.Cookie) require.NoError(t, err) err = redisStore.Save( diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 1a60aa0d..70aa4e52 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -353,24 +353,11 @@ var _ = Describe("NewSessionStore", func() { SameSite: "strict", } - var err error - ss, err = sessions.NewSessionStore(opts, cookieOpts) - Expect(err).ToNot(HaveOccurred()) - }) - - SessionStoreInterfaceTests(persistent) - }) - - Context("with a cipher", func() { - BeforeEach(func() { + // A secret is required but not defaulted secret := make([]byte, 32) _, err := rand.Read(secret) Expect(err).ToNot(HaveOccurred()) cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret) - cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret)) - Expect(err).ToNot(HaveOccurred()) - Expect(cipher).ToNot(BeNil()) - opts.Cipher = cipher ss, err = sessions.NewSessionStore(opts, cookieOpts) Expect(err).ToNot(HaveOccurred()) @@ -384,9 +371,16 @@ var _ = Describe("NewSessionStore", func() { ss = nil opts = &options.SessionOptions{} + // A secret is required to create a Cipher, validation ensures it is the correct + // length before a session store is initialised. + secret := make([]byte, 32) + _, err := rand.Read(secret) + Expect(err).ToNot(HaveOccurred()) + // Set default options in CookieOptions cookieOpts = &options.CookieOptions{ Name: "_oauth2_proxy", + Secret: base64.URLEncoding.EncodeToString(secret), Path: "/", Expire: time.Duration(168) * time.Hour, Refresh: time.Duration(1) * time.Hour, @@ -423,6 +417,19 @@ var _ = Describe("NewSessionStore", func() { Context("the cookie.SessionStore", func() { RunSessionTests(false) }) + + Context("with an invalid cookie secret", func() { + BeforeEach(func() { + cookieOpts.Secret = "invalid" + }) + + It("returns an error", func() { + ss, err := sessions.NewSessionStore(opts, cookieOpts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("error initialising cipher: crypto/aes: invalid key size 7")) + Expect(ss).To(BeNil()) + }) + }) }) Context("with type 'redis'", func() { @@ -447,6 +454,19 @@ var _ = Describe("NewSessionStore", func() { Context("the redis.SessionStore", func() { RunSessionTests(true) }) + + Context("with an invalid cookie secret", func() { + BeforeEach(func() { + cookieOpts.Secret = "invalid" + }) + + It("returns an error", func() { + ss, err := sessions.NewSessionStore(opts, cookieOpts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("error initialising cipher: crypto/aes: invalid key size 7")) + Expect(ss).To(BeNil()) + }) + }) }) Context("with an invalid type", func() { diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 44a3e758..611ea92b 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -21,7 +21,6 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" - "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/providers" ) @@ -37,8 +36,6 @@ func Validate(o *options.Options) error { } msgs := make([]string, 0) - - var cipher encryption.Cipher if o.Cookie.Secret == "" { msgs = append(msgs, "missing setting: cookie-secret") } else { @@ -60,12 +57,6 @@ func Validate(o *options.Options) error { msgs = append(msgs, fmt.Sprintf("Cookie secret must be 16, 24, or 32 bytes to create an AES cipher. Got %d bytes.%s", len(encryption.SecretBytes(o.Cookie.Secret)), suffix)) - } else { - var err error - cipher, err = encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(o.Cookie.Secret)) - if err != nil { - msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) - } } } @@ -218,14 +209,6 @@ func Validate(o *options.Options) error { } msgs = parseProviderInfo(o, msgs) - o.Session.Cipher = cipher - sessionStore, err := sessions.NewSessionStore(&o.Session, &o.Cookie) - if err != nil { - msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err)) - } else { - o.SetSessionStore(sessionStore) - } - if o.Cookie.Refresh >= o.Cookie.Expire { msgs = append(msgs, fmt.Sprintf( "cookie_refresh (%s) must be less than "+