Merge pull request #1 from miggo-io/fix-readiness-check
add graceful shutdown
This commit is contained in:
commit
7102fe1853
18
main.go
18
main.go
|
|
@ -1,9 +1,12 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"github.com/ghodss/yaml"
|
"github.com/ghodss/yaml"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
|
@ -54,12 +57,23 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile)
|
validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile)
|
||||||
oauthproxy, err := NewOAuthProxy(opts, validator)
|
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Observe signals in background goroutine.
|
||||||
|
go func() {
|
||||||
|
sigint := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
||||||
|
<-sigint
|
||||||
|
cancelFunc() // cancel the context
|
||||||
|
}()
|
||||||
|
|
||||||
|
oauthproxy, err := NewOAuthProxy(ctx, opts, validator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
|
logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := oauthproxy.Start(); err != nil {
|
if err := oauthproxy.Start(ctx); err != nil {
|
||||||
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
|
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,8 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
@ -117,7 +114,7 @@ type OAuthProxy struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
|
||||||
func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) {
|
func NewOAuthProxy(ctx context.Context, opts *options.Options, validator func(string) bool) (*OAuthProxy, error) {
|
||||||
sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie)
|
sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error initialising session store: %v", err)
|
return nil, fmt.Errorf("error initialising session store: %v", err)
|
||||||
|
|
@ -200,7 +197,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
preAuthChain, err := buildPreAuthChain(opts, sessionStore)
|
preAuthChain, err := buildPreAuthChain(ctx, opts, sessionStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
|
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -258,23 +255,13 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) Start() error {
|
func (p *OAuthProxy) Start(ctx context.Context) error {
|
||||||
if p.server == nil {
|
if p.server == nil {
|
||||||
// We have to call setupServer before Start is called.
|
// We have to call setupServer before Start is called.
|
||||||
// If this doesn't happen it's a programming error.
|
// If this doesn't happen it's a programming error.
|
||||||
panic("server has not been initialised")
|
panic("server has not been initialised")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
// Observe signals in background goroutine.
|
|
||||||
go func() {
|
|
||||||
sigint := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
|
|
||||||
<-sigint
|
|
||||||
cancel() // cancel the context
|
|
||||||
}()
|
|
||||||
|
|
||||||
return p.server.Start(ctx)
|
return p.server.Start(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -284,6 +271,7 @@ func (p *OAuthProxy) setupServer(opts *options.Options) error {
|
||||||
BindAddress: opts.Server.BindAddress,
|
BindAddress: opts.Server.BindAddress,
|
||||||
SecureBindAddress: opts.Server.SecureBindAddress,
|
SecureBindAddress: opts.Server.SecureBindAddress,
|
||||||
TLS: opts.Server.TLS,
|
TLS: opts.Server.TLS,
|
||||||
|
ShutdownDuration: opts.ShutdownDuration,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option: AllowQuerySemicolons
|
// Option: AllowQuerySemicolons
|
||||||
|
|
@ -353,7 +341,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
|
||||||
// buildPreAuthChain constructs a chain that should process every request before
|
// buildPreAuthChain constructs a chain that should process every request before
|
||||||
// the OAuth2 Proxy authentication logic kicks in.
|
// the OAuth2 Proxy authentication logic kicks in.
|
||||||
// For example forcing HTTPS or health checks.
|
// For example forcing HTTPS or health checks.
|
||||||
func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) {
|
func buildPreAuthChain(ctx context.Context, opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) {
|
||||||
chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
|
chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
|
||||||
|
|
||||||
if opts.ForceHTTPS {
|
if opts.ForceHTTPS {
|
||||||
|
|
@ -374,17 +362,18 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
||||||
|
|
||||||
// To silence logging of health checks, register the health check handler before
|
// To silence logging of health checks, register the health check handler before
|
||||||
// the logging handler
|
// the logging handler
|
||||||
|
readinessCheck := middleware.NewReadynessCheck(ctx, opts.ReadyPath, sessionStore)
|
||||||
if opts.Logging.SilencePing {
|
if opts.Logging.SilencePing {
|
||||||
chain = chain.Append(
|
chain = chain.Append(
|
||||||
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
|
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
|
||||||
middleware.NewReadynessCheck(opts.ReadyPath, sessionStore),
|
readinessCheck,
|
||||||
middleware.NewRequestLogger(),
|
middleware.NewRequestLogger(),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
chain = chain.Append(
|
chain = chain.Append(
|
||||||
middleware.NewRequestLogger(),
|
middleware.NewRequestLogger(),
|
||||||
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
|
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
|
||||||
middleware.NewReadynessCheck(opts.ReadyPath, sessionStore),
|
readinessCheck,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ func TestRobotsTxt(t *testing.T) {
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -108,7 +108,7 @@ func Test_redeemCode(t *testing.T) {
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -160,7 +160,7 @@ func Test_enrichSession(t *testing.T) {
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -229,7 +229,7 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||||
providerURL, _ := url.Parse(providerServer.URL)
|
providerURL, _ := url.Parse(providerServer.URL)
|
||||||
const emailAddress = "john.doe@example.com"
|
const emailAddress = "john.doe@example.com"
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
|
||||||
return email == emailAddress
|
return email == emailAddress
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -291,7 +291,7 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) {
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
|
||||||
return email == emailAddress
|
return email == emailAddress
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
@ -387,7 +387,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
|
||||||
|
|
||||||
testProvider := NewTestProvider(providerURL, emailAddress)
|
testProvider := NewTestProvider(providerURL, emailAddress)
|
||||||
testProvider.ValidToken = opts.ValidToken
|
testProvider.ValidToken = opts.ValidToken
|
||||||
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
|
patt.proxy, err = NewOAuthProxy(context.Background(), patt.opts, func(email string) bool {
|
||||||
return email == emailAddress
|
return email == emailAddress
|
||||||
})
|
})
|
||||||
patt.proxy.provider = testProvider
|
patt.proxy.provider = testProvider
|
||||||
|
|
@ -592,7 +592,7 @@ func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool {
|
sipTest.proxy, err = NewOAuthProxy(context.Background(), sipTest.opts, func(email string) bool {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -628,7 +628,7 @@ func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -676,7 +676,7 @@ func ManualSignInWithCredentials(t *testing.T, user, pass string) int {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -826,7 +826,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
|
||||||
return pcTest.validateUser
|
return pcTest.validateUser
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1200,7 +1200,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||||
err := validation.Validate(pcTest.opts)
|
err := validation.Validate(pcTest.opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
|
||||||
return pcTest.validateUser
|
return pcTest.validateUser
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1293,7 +1293,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
|
||||||
err := validation.Validate(pcTest.opts)
|
err := validation.Validate(pcTest.opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
|
||||||
return pcTest.validateUser
|
return pcTest.validateUser
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1373,7 +1373,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
|
||||||
err := validation.Validate(pcTest.opts)
|
err := validation.Validate(pcTest.opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
|
pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
|
||||||
return pcTest.validateUser
|
return pcTest.validateUser
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1429,7 +1429,7 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
||||||
|
|
||||||
upstreamURL, _ := url.Parse(upstreamServer.URL)
|
upstreamURL, _ := url.Parse(upstreamServer.URL)
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return false })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -1554,7 +1554,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), st.opts, func(email string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -1642,7 +1642,7 @@ func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
|
test.proxy, err = NewOAuthProxy(context.Background(), test.opts, func(email string) bool {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1958,7 +1958,7 @@ func Test_noCacheHeaders(t *testing.T) {
|
||||||
opts.SkipAuthRegex = []string{".*"}
|
opts.SkipAuthRegex = []string{".*"}
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -2233,7 +2233,7 @@ func TestTrustedIPs(t *testing.T) {
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|
@ -2478,7 +2478,7 @@ func TestApiRoutes(t *testing.T) {
|
||||||
opts.SkipProviderButton = true
|
opts.SkipProviderButton = true
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -2561,7 +2561,7 @@ func TestAllowedRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -2660,7 +2660,7 @@ func TestAllowedRequestWithForwardedUriHeader(t *testing.T) {
|
||||||
}
|
}
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -2759,7 +2759,7 @@ func TestAllowedRequestNegateWithoutMethod(t *testing.T) {
|
||||||
}
|
}
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -2859,7 +2859,7 @@ func TestAllowedRequestNegateWithMethod(t *testing.T) {
|
||||||
}
|
}
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -3460,7 +3460,7 @@ func TestGetOAuthRedirectURI(t *testing.T) {
|
||||||
err := validation.Validate(baseOpts)
|
err := validation.Validate(baseOpts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(tt.setupOpts(baseOpts), func(string) bool { return true })
|
proxy, err := NewOAuthProxy(context.Background(), tt.setupOpts(baseOpts), func(string) bool { return true })
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req)
|
assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package options
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip"
|
ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip"
|
||||||
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc"
|
internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc"
|
||||||
|
|
@ -29,6 +30,8 @@ type Options struct {
|
||||||
RawRedirectURL string `flag:"redirect-url" cfg:"redirect_url"`
|
RawRedirectURL string `flag:"redirect-url" cfg:"redirect_url"`
|
||||||
RelativeRedirectURL bool `flag:"relative-redirect-url" cfg:"relative_redirect_url"`
|
RelativeRedirectURL bool `flag:"relative-redirect-url" cfg:"relative_redirect_url"`
|
||||||
|
|
||||||
|
ShutdownDuration time.Duration `flag:"shutdown-duration" cfg:"shutdown_duration"`
|
||||||
|
|
||||||
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
||||||
EmailDomains []string `flag:"email-domain" cfg:"email_domains"`
|
EmailDomains []string `flag:"email-domain" cfg:"email_domains"`
|
||||||
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
|
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"`
|
||||||
|
|
@ -161,6 +164,7 @@ func NewFlagSet() *pflag.FlagSet {
|
||||||
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
|
flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option")
|
||||||
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
|
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
|
||||||
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
||||||
|
flagSet.Duration("shutdown-duration", 0, "Amount of time to continue serving traffic after receiving an exit signal with readiness endpoint set to false.")
|
||||||
|
|
||||||
flagSet.AddFlagSet(cookieFlagSet())
|
flagSet.AddFlagSet(cookieFlagSet())
|
||||||
flagSet.AddFlagSet(loggingFlagSet())
|
flagSet.AddFlagSet(loggingFlagSet())
|
||||||
|
|
|
||||||
|
|
@ -40,12 +40,16 @@ type Opts struct {
|
||||||
|
|
||||||
// Let testing infrastructure circumvent parsing file descriptors
|
// Let testing infrastructure circumvent parsing file descriptors
|
||||||
fdFiles []*os.File
|
fdFiles []*os.File
|
||||||
|
|
||||||
|
// Graceful shutdown duration
|
||||||
|
ShutdownDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Server from the options given.
|
// NewServer creates a new Server from the options given.
|
||||||
func NewServer(opts Opts) (Server, error) {
|
func NewServer(opts Opts) (Server, error) {
|
||||||
s := &server{
|
s := &server{
|
||||||
handler: opts.Handler,
|
handler: opts.Handler,
|
||||||
|
shutdownDuration: opts.ShutdownDuration,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(opts.fdFiles) > 0 {
|
if len(opts.fdFiles) > 0 {
|
||||||
|
|
@ -71,6 +75,9 @@ type server struct {
|
||||||
|
|
||||||
// ensure activation.Files are called once
|
// ensure activation.Files are called once
|
||||||
fdFiles []*os.File
|
fdFiles []*os.File
|
||||||
|
|
||||||
|
// Graceful shutdown duration
|
||||||
|
shutdownDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupListener sets the server listener if the HTTP server is enabled.
|
// setupListener sets the server listener if the HTTP server is enabled.
|
||||||
|
|
@ -214,10 +221,16 @@ func (s *server) startServer(ctx context.Context, listener net.Listener) error {
|
||||||
|
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
<-groupCtx.Done()
|
<-groupCtx.Done()
|
||||||
|
logger.Printf("Context canceled. Waiting %s before shutting down the listeners.", s.shutdownDuration)
|
||||||
|
|
||||||
|
time.Sleep(s.shutdownDuration)
|
||||||
|
|
||||||
|
logger.Printf("Shutting down listener.")
|
||||||
|
|
||||||
if err := srv.Shutdown(context.Background()); err != nil {
|
if err := srv.Shutdown(context.Background()); err != nil {
|
||||||
return fmt.Errorf("error shutting down server: %v", err)
|
return fmt.Errorf("error shutting down server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,19 @@ type Verifiable interface {
|
||||||
|
|
||||||
// NewReadynessCheck returns a middleware that performs deep health checks
|
// NewReadynessCheck returns a middleware that performs deep health checks
|
||||||
// (verifies the connection to any underlying store) on a specific `path`
|
// (verifies the connection to any underlying store) on a specific `path`
|
||||||
func NewReadynessCheck(path string, verifiable Verifiable) alice.Constructor {
|
func NewReadynessCheck(ctx context.Context, path string, verifiable Verifiable) alice.Constructor {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return readynessCheck(path, verifiable, next)
|
return readynessCheck(ctx, path, verifiable, next)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func readynessCheck(path string, verifiable Verifiable, next http.Handler) http.Handler {
|
func readynessCheck(ctx context.Context, path string, verifiable Verifiable, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
fmt.Fprintf(rw, "Shutting down")
|
||||||
|
return
|
||||||
|
}
|
||||||
if path != "" && req.URL.EscapedPath() == path {
|
if path != "" && req.URL.EscapedPath() == path {
|
||||||
if err := verifiable.VerifyConnection(req.Context()); err != nil {
|
if err := verifiable.VerifyConnection(req.Context()); err != nil {
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
rw.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,9 @@ var _ = Describe("ReadynessCheck suite", func() {
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
handler := NewReadynessCheck(in.readyPath, in.healthVerifiable)(http.NotFoundHandler())
|
ctx := context.Background()
|
||||||
|
|
||||||
|
handler := NewReadynessCheck(ctx, in.readyPath, in.healthVerifiable)(http.NotFoundHandler())
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
Expect(rw.Code).To(Equal(in.expectedStatus))
|
Expect(rw.Code).To(Equal(in.expectedStatus))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue