diff --git a/http.go b/http.go deleted file mode 100644 index 34700380..00000000 --- a/http.go +++ /dev/null @@ -1,136 +0,0 @@ -package main - -import ( - "context" - "crypto/tls" - "errors" - "net" - "net/http" - "strings" - "time" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" -) - -// Server represents an HTTP server -type Server struct { - Handler http.Handler - Opts *options.Options - stop chan struct{} // channel for waiting shutdown -} - -// ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options -func (s *Server) ListenAndServe() { - if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { - s.ServeHTTPS() - } else { - s.ServeHTTP() - } -} - -// ServeHTTP constructs a net.Listener and starts handling HTTP requests -func (s *Server) ServeHTTP() { - HTTPAddress := s.Opts.HTTPAddress - var scheme string - - i := strings.Index(HTTPAddress, "://") - if i > -1 { - scheme = HTTPAddress[0:i] - } - - var networkType string - switch scheme { - case "", "http": - networkType = "tcp" - default: - networkType = scheme - } - - slice := strings.SplitN(HTTPAddress, "//", 2) - listenAddr := slice[len(slice)-1] - - listener, err := net.Listen(networkType, listenAddr) - if err != nil { - logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err) - } - logger.Printf("HTTP: listening on %s", listenAddr) - s.serve(listener) - logger.Printf("HTTP: closing %s", listener.Addr()) -} - -// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests -func (s *Server) ServeHTTPS() { - addr := s.Opts.HTTPSAddress - config := &tls.Config{ - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS13, - } - if config.NextProtos == nil { - config.NextProtos = []string{"http/1.1"} - } - - var err error - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile) - if err != nil { - logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err) - } - - ln, err := net.Listen("tcp", addr) - if err != nil { - logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err) - } - logger.Printf("HTTPS: listening on %s", ln.Addr()) - - tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) - s.serve(tlsListener) - logger.Printf("HTTPS: closing %s", tlsListener.Addr()) -} - -func (s *Server) serve(listener net.Listener) { - srv := &http.Server{Handler: s.Handler} - - // See https://golang.org/pkg/net/http/#Server.Shutdown - idleConnsClosed := make(chan struct{}) - go func() { - <-s.stop // wait notification for stopping server - - // We received an interrupt signal, shut down. - if err := srv.Shutdown(context.Background()); err != nil { - // Error from closing listeners, or context timeout: - logger.Printf("HTTP server Shutdown: %v", err) - } - close(idleConnsClosed) - }() - - err := srv.Serve(listener) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.Errorf("ERROR: http.Serve() - %s", err) - } - <-idleConnsClosed -} - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { - tc, err := ln.AcceptTCP() - if err != nil { - return nil, err - } - err = tc.SetKeepAlive(true) - if err != nil { - logger.Printf("Error setting Keep-Alive: %v", err) - } - err = tc.SetKeepAlivePeriod(3 * time.Minute) - if err != nil { - logger.Printf("Error setting Keep-Alive period: %v", err) - } - return tc, nil -} diff --git a/http_test.go b/http_test.go deleted file mode 100644 index f4e12843..00000000 --- a/http_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package main - -import ( - "net/http" - "sync" - "testing" - "time" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" - "github.com/stretchr/testify/assert" -) - -func TestGracefulShutdown(t *testing.T) { - opts := options.NewOptions() - stop := make(chan struct{}, 1) - srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop} - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - srv.ServeHTTP() - }() - - stop <- struct{}{} // emulate catching signals - - // An idiomatic for sync.WaitGroup with timeout - c := make(chan struct{}) - go func() { - defer close(c) - wg.Wait() - }() - select { - case <-c: - case <-time.After(1 * time.Second): - t.Fatal("Server should return gracefully but timeout has occurred") - } - - assert.Len(t, stop, 0) // check if stop chan is empty -} diff --git a/main.go b/main.go index 924e875e..97f2c5a0 100644 --- a/main.go +++ b/main.go @@ -1,20 +1,15 @@ package main import ( - "context" "fmt" "math/rand" - "net/http" "os" - "os/signal" "runtime" - "syscall" "time" "github.com/ghodss/yaml" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" "github.com/spf13/pflag" ) @@ -67,54 +62,9 @@ func main() { rand.Seed(time.Now().UnixNano()) - oauthProxyStop := make(chan struct{}, 1) - metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop) - - s := &Server{ - Handler: oauthproxy, - Opts: opts, - stop: oauthProxyStop, + if err := oauthproxy.Start(); err != nil { + logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) } - // Observe signals in background goroutine. - go func() { - sigint := make(chan os.Signal, 1) - signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) - <-sigint - s.stop <- struct{}{} // notify having caught signal stop oauthproxy - close(metricsStop) // and the metrics endpoint - }() - s.ListenAndServe() -} - -// startMetricsServer will start the metrics server on the specified address. -// It always return a channel to signal stop even when it does not run. -func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} { - stop := make(chan struct{}, 1) - - // Attempt to setup the metrics endpoint if we have an address - if address != "" { - s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler} - go func() { - // ListenAndServe always returns a non-nil error. After Shutdown or - // Close, the returned error is ErrServerClosed - if err := s.ListenAndServe(); err != http.ErrServerClosed { - logger.Println(err) - // Stop the metrics shutdown go routine - close(stop) - // Stop the oauthproxy server, we have encounter an unexpected error - close(oauthProxyStop) - } - }() - - go func() { - <-stop - if err := s.Shutdown(context.Background()); err != nil { - logger.Print(err) - } - }() - } - - return stop } // loadConfiguration will load in the user's configuration. diff --git a/main_test.go b/main_test.go index a91940e7..3273abfa 100644 --- a/main_test.go +++ b/main_test.go @@ -15,6 +15,7 @@ import ( var _ = Describe("Configuration Loading Suite", func() { const testLegacyConfig = ` +http_address="127.0.0.1:4180" upstreams="http://httpbin" set_basic_auth="true" basic_auth_password="super-secret-password" @@ -54,10 +55,11 @@ injectResponseHeaders: prefix: "Basic " basicAuthPassword: value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk +server: + bindAddress: "127.0.0.1:4180" ` const testCoreConfig = ` -http_address="0.0.0.0:4180" cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" provider="oidc" email_domains="example.com" @@ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback" opts, err := options.NewLegacyOptions().ToOptions() Expect(err).ToNot(HaveOccurred()) - opts.HTTPAddress = "0.0.0.0:4180" opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" opts.ProviderType = "oidc" opts.EmailDomains = []string{"example.com"} @@ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback" configContent: testCoreConfig, alphaConfigContent: testAlphaConfig + ":", expectedOptions: func() *options.Options { return nil }, - expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"), + expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"), }), Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{ configContent: testCoreConfig + "unknown_field=\"something\"", diff --git a/oauthproxy.go b/oauthproxy.go index 43c64525..0a9669f3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -8,8 +8,11 @@ import ( "net" "net/http" "net/url" + "os" + "os/signal" "regexp" "strings" + "syscall" "time" "github.com/justinas/alice" @@ -21,6 +24,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" + proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" @@ -102,6 +106,7 @@ type OAuthProxy struct { headersChain alice.Chain preAuthChain alice.Chain pageWriter pagewriter.Writer + server proxyhttp.Server } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("could not build headers chain: %v", err) } - return &OAuthProxy{ + p := &OAuthProxy{ CookieName: opts.Cookie.Name, CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), CookieSeed: opts.Cookie.Secret, @@ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr headersChain: headersChain, preAuthChain: preAuthChain, pageWriter: pageWriter, - }, nil + } + + if err := p.setupServer(opts); err != nil { + return nil, fmt.Errorf("error setting up server: %v", err) + } + + return p, nil +} + +func (p *OAuthProxy) Start() error { + if p.server == nil { + // We have to call setupServer before Start is called. + // If this doesn't happen it's a programming error. + panic("server has not been initialised") + } + + ctx, cancel := context.WithCancel(context.Background()) + + // Observe signals in background goroutine. + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) + <-sigint + cancel() // cancel the context + }() + + return p.server.Start(ctx) +} + +func (p *OAuthProxy) setupServer(opts *options.Options) error { + serverOpts := proxyhttp.Opts{ + Handler: p, + BindAddress: opts.Server.BindAddress, + SecureBindAddress: opts.Server.SecureBindAddress, + TLS: opts.Server.TLS, + } + + appServer, err := proxyhttp.NewServer(serverOpts) + if err != nil { + return fmt.Errorf("could not build app server: %v", err) + } + + metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{ + Handler: middleware.DefaultMetricsHandler, + BindAddress: opts.MetricsServer.BindAddress, + SecureBindAddress: opts.MetricsServer.BindAddress, + TLS: opts.MetricsServer.TLS, + }) + if err != nil { + return fmt.Errorf("could not build metrics server: %v", err) + } + + p.server = proxyhttp.NewServerGroup(appServer, metricsServer) + return nil } // buildPreAuthChain constructs a chain that should process every request before @@ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy)) if opts.ForceHTTPS { - _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) + _, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress) if err != nil { - return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err) + return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err) } chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 06014c9b..a2805c34 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options { }, }, } + return opts }