diff --git a/.golangci.yml b/.golangci.yml index 8ddb2187..8cab4291 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,6 +9,7 @@ linters: - deadcode - gofmt - goimports + - gosec - gosimple - staticcheck - structcheck @@ -33,6 +34,7 @@ issues: - bodyclose - unconvert - gocritic + - gosec # If we have tests in shared test folders, these can be less strictly linted - path: tests/.*_tests\.go linters: diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b72f595..23c6e9a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ## Changes since v6.0.0 +- [#690](https://github.com/oauth2-proxy/oauth2-proxy/pull/690) Address GoSec security findings & remediate (@NickMeves) - [#689](https://github.com/oauth2-proxy/oauth2-proxy/pull/689) Fix finicky logging_handler_test from time drift (@NickMeves) - [#699](https://github.com/oauth2-proxy/oauth2-proxy/pull/699) Align persistence ginkgo tests with conventions (@NickMeves) - [#696](https://github.com/oauth2-proxy/oauth2-proxy/pull/696) Preserve query when building redirect diff --git a/http.go b/http.go index 7e1215f7..41198de4 100644 --- a/http.go +++ b/http.go @@ -119,12 +119,18 @@ type tcpKeepAliveListener struct { *net.TCPListener } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { +func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { tc, err := ln.AcceptTCP() if err != nil { - return + return nil, err + } + err = tc.SetKeepAlive(true) + if err != nil { + logger.Printf("Error setting Keep-Alive: %v", err) + } + err = tc.SetKeepAlivePeriod(3 * time.Minute) + if err != nil { + logger.Printf("Error setting Keep-Alive period: %v", err) } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) return tc, nil } diff --git a/main.go b/main.go index 24d48072..ef9ac44e 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "os" "os/signal" "runtime" - "strings" "syscall" "time" @@ -25,7 +24,11 @@ func main() { config := flagSet.String("config", "", "path to config file") showVersion := flagSet.Bool("version", false, "print version string") - flagSet.Parse(os.Args[1:]) + err := flagSet.Parse(os.Args[1:]) + if err != nil { + logger.Printf("ERROR: Failed to parse flags: %v", err) + os.Exit(1) + } if *showVersion { fmt.Printf("oauth2-proxy %s (built with %s)\n", VERSION, runtime.Version()) @@ -33,7 +36,7 @@ func main() { } legacyOpts := options.NewLegacyOptions() - err := options.Load(*config, flagSet, legacyOpts) + err = options.Load(*config, flagSet, legacyOpts) if err != nil { logger.Printf("ERROR: Failed to load config: %v", err) os.Exit(1) @@ -58,20 +61,6 @@ func main() { os.Exit(1) } - if len(opts.Banner) >= 1 { - if opts.Banner == "-" { - oauthproxy.SignInMessage = "" - } else { - oauthproxy.SignInMessage = opts.Banner - } - } else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { - if len(opts.EmailDomains) > 1 { - oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) - } else if opts.EmailDomains[0] != "*" { - oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) - } - } - rand.Seed(time.Now().UnixNano()) chain := alice.New() diff --git a/oauthproxy.go b/oauthproxy.go index 64df8eb1..ac9e0565 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -213,6 +213,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr trustedIPs: trustedIPs, Banner: opts.Banner, Footer: opts.Footer, + SignInMessage: buildSignInMessage(opts), basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil, @@ -255,6 +256,24 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt return chain } +func buildSignInMessage(opts *options.Options) string { + var msg string + if len(opts.Banner) >= 1 { + if opts.Banner == "-" { + msg = "" + } else { + msg = opts.Banner + } + } else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { + if len(opts.EmailDomains) > 1 { + msg = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) + } else if opts.EmailDomains[0] != "*" { + msg = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) + } + } + return msg +} + // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will // redirect clients to once authenticated func (p *OAuthProxy) GetRedirectURI(host string) string { @@ -363,8 +382,13 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s // RobotsTxt disallows scraping pages from the OAuthProxy func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { + _, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") + if err != nil { + logger.Printf("Error writing robots.txt: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + return + } rw.WriteHeader(http.StatusOK) - fmt.Fprintf(rw, "User-agent: *\nDisallow: /") } // ErrorPage writes an error response @@ -379,19 +403,28 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m Message: message, ProxyPrefix: p.ProxyPrefix, } - p.templates.ExecuteTemplate(rw, "error.html", t) + err := p.templates.ExecuteTemplate(rw, "error.html", t) + if err != nil { + logger.Printf("Error rendering error.html template: %v", err) + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + } } // SignInPage writes the sing in template to the response func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { prepareNoCache(rw) - p.ClearSessionCookie(rw, req) + err := p.ClearSessionCookie(rw, req) + if err != nil { + logger.Printf("Error clearing session cookie: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + return + } rw.WriteHeader(code) redirectURL, err := p.GetRedirect(req) if err != nil { - logger.Printf("Error obtaining redirect: %s", err.Error()) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) + logger.Printf("Error obtaining redirect: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } @@ -399,6 +432,8 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code redirectURL = "/" } + // We allow unescaped template.HTML since it is user configured options + /* #nosec G203 */ t := struct { ProviderName string SignInMessage template.HTML @@ -419,11 +454,15 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code if p.providerNameOverride != "" { t.ProviderName = p.providerNameOverride } - p.templates.ExecuteTemplate(rw, "sign_in.html", t) + err = p.templates.ExecuteTemplate(rw, "sign_in.html", t) + if err != nil { + logger.Printf("Error rendering sign_in.html template: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + } } // ManualSignIn handles basic auth logins to the proxy -func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { +func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { if req.Method != "POST" || p.basicAuthValidator == nil { return "", false } @@ -627,15 +666,20 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { redirect, err := p.GetRedirect(req) if err != nil { - logger.Printf("Error obtaining redirect: %s", err.Error()) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) + logger.Printf("Error obtaining redirect: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - user, ok := p.ManualSignIn(rw, req) + user, ok := p.ManualSignIn(req) if ok { session := &sessionsapi.SessionState{User: user} - p.SaveSession(rw, req, session) + err = p.SaveSession(rw, req, session) + if err != nil { + logger.Printf("Error saving session: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + return + } http.Redirect(rw, req, redirect, http.StatusFound) } else { if p.SkipProviderButton { @@ -663,18 +707,27 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { } rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(http.StatusOK) - json.NewEncoder(rw).Encode(userInfo) + err = json.NewEncoder(rw).Encode(userInfo) + if err != nil { + logger.Printf("Error encoding user info: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + } } // SignOut sends a response to clear the authentication cookie func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { redirect, err := p.GetRedirect(req) if err != nil { - logger.Printf("Error obtaining redirect: %s", err.Error()) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) + logger.Printf("Error obtaining redirect: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) + return + } + err = p.ClearSessionCookie(rw, req) + if err != nil { + logger.Printf("Error clearing session cookie: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - p.ClearSessionCookie(rw, req) http.Redirect(rw, req, redirect, http.StatusFound) } @@ -683,15 +736,15 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { prepareNoCache(rw) nonce, err := encryption.Nonce() if err != nil { - logger.Printf("Error obtaining nonce: %s", err.Error()) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) + logger.Printf("Error obtaining nonce: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } p.SetCSRFCookie(rw, req, nonce) redirect, err := p.GetRedirect(req) if err != nil { - logger.Printf("Error obtaining redirect: %s", err.Error()) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) + logger.Printf("Error obtaining redirect: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } redirectURI := p.GetRedirectURI(req.Host) @@ -706,42 +759,42 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { // finish the oauth cycle err := req.ParseForm() if err != nil { - logger.Printf("Error while parsing OAuth2 callback: %s" + err.Error()) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) + logger.Printf("Error while parsing OAuth2 callback: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } errorString := req.Form.Get("error") if errorString != "" { - logger.Printf("Error while parsing OAuth2 callback: %s ", errorString) - p.ErrorPage(rw, 403, "Permission Denied", errorString) + logger.Printf("Error while parsing OAuth2 callback: %s", errorString) + p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", errorString) return } session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) if err != nil { - logger.Printf("Error redeeming code during OAuth2 callback: %s ", err.Error()) - p.ErrorPage(rw, 500, "Internal Error", "Internal Error") + logger.Printf("Error redeeming code during OAuth2 callback: %v", err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") return } s := strings.SplitN(req.Form.Get("state"), ":", 2) if len(s) != 2 { logger.Printf("Error while parsing OAuth2 state: invalid length") - p.ErrorPage(rw, 500, "Internal Error", "Invalid State") + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State") return } nonce := s[0] redirect := s[1] c, err := req.Cookie(p.CSRFCookieName) if err != nil { - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable too obtain CSRF cookie") - p.ErrorPage(rw, 403, "Permission Denied", err.Error()) + logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") + p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", err.Error()) return } p.ClearCSRFCookie(rw, req) if c.Value != nonce { - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: csrf token mismatch, potential attack") - p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") + logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack") + p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") return } @@ -754,14 +807,14 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) err := p.SaveSession(rw, req, session) if err != nil { - logger.Printf("%s %s", remoteAddr, err) - p.ErrorPage(rw, 500, "Internal Error", "Internal Error") + logger.Printf("Error saving session state for %s: %v", remoteAddr, err) + p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } http.Redirect(rw, req, redirect, http.StatusFound) } else { logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized") - p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") + p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "Invalid Account") } } @@ -837,7 +890,10 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R if session != nil && session.Email != "" && !p.Validator(session.Email) { logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) // Invalid session, clear it - p.ClearSessionCookie(rw, req) + err := p.ClearSessionCookie(rw, req) + if err != nil { + logger.Printf("Error clearing session cookie: %v", err) + } return nil, ErrNeedsLogin } diff --git a/pkg/authentication/basic/htpasswd.go b/pkg/authentication/basic/htpasswd.go index ff87894f..321b7bae 100644 --- a/pkg/authentication/basic/htpasswd.go +++ b/pkg/authentication/basic/htpasswd.go @@ -1,7 +1,8 @@ package basic import ( - "crypto/sha1" + // We support SHA1 & bcrypt in HTPasswd + "crypto/sha1" // #nosec G505 "encoding/base64" "encoding/csv" "fmt" @@ -29,11 +30,17 @@ type sha1Pass string // NewHTPasswdValidator constructs an httpasswd based validator from the file // at the path given. func NewHTPasswdValidator(path string) (Validator, error) { - r, err := os.Open(path) + // We allow HTPasswd location via config options + r, err := os.Open(path) // #nosec G304 if err != nil { return nil, fmt.Errorf("could not open htpasswd file: %v", err) } - defer r.Close() + defer func(c io.Closer) { + cerr := c.Close() + if cerr != nil { + logger.Fatalf("error closing the htpasswd file: %v", cerr) + } + }(r) return newHtpasswd(r) } @@ -83,13 +90,17 @@ func (h *htpasswdMap) Validate(user string, password string) bool { return false } - switch real := realPassword.(type) { + switch rp := realPassword.(type) { case sha1Pass: - d := sha1.New() - d.Write([]byte(password)) - return string(real) == base64.StdEncoding.EncodeToString(d.Sum(nil)) + // We support SHA1 HTPasswd entries + d := sha1.New() // #nosec G401 + _, err := d.Write([]byte(password)) + if err != nil { + return false + } + return string(rp) == base64.StdEncoding.EncodeToString(d.Sum(nil)) case bcryptPass: - return bcrypt.CompareHashAndPassword([]byte(real), []byte(password)) == nil + return bcrypt.CompareHashAndPassword([]byte(rp), []byte(password)) == nil default: return false } diff --git a/pkg/encryption/utils.go b/pkg/encryption/utils.go index 26be6b24..269a89c6 100644 --- a/pkg/encryption/utils.go +++ b/pkg/encryption/utils.go @@ -2,7 +2,8 @@ package encryption import ( "crypto/hmac" - "crypto/sha1" + // TODO (@NickMeves): Remove SHA1 signed cookie support in V7 + "crypto/sha1" // #nosec G505 "crypto/sha256" "encoding/base64" "fmt" @@ -65,32 +66,44 @@ func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value } // SignedValue returns a cookie that is signed and can later be checked with Validate -func SignedValue(seed string, key string, value []byte, now time.Time) string { +func SignedValue(seed string, key string, value []byte, now time.Time) (string, error) { encodedValue := base64.URLEncoding.EncodeToString(value) timeStr := fmt.Sprintf("%d", now.Unix()) - sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr) + sig, err := cookieSignature(sha256.New, seed, key, encodedValue, timeStr) + if err != nil { + return "", err + } cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) - return cookieVal + return cookieVal, nil } -func cookieSignature(signer func() hash.Hash, args ...string) string { +func cookieSignature(signer func() hash.Hash, args ...string) (string, error) { h := hmac.New(signer, []byte(args[0])) for _, arg := range args[1:] { - h.Write([]byte(arg)) + _, err := h.Write([]byte(arg)) + if err != nil { + return "", err + } } var b []byte b = h.Sum(b) - return base64.URLEncoding.EncodeToString(b) + return base64.URLEncoding.EncodeToString(b), nil } func checkSignature(signature string, args ...string) bool { - checkSig := cookieSignature(sha256.New, args...) + checkSig, err := cookieSignature(sha256.New, args...) + if err != nil { + return false + } if checkHmac(signature, checkSig) { return true } - // TODO: After appropriate rollout window, remove support for SHA1 - legacySig := cookieSignature(sha1.New, args...) + // TODO (@NickMeves): Remove SHA1 signed cookie support in V7 + legacySig, err := cookieSignature(sha1.New, args...) + if err != nil { + return false + } return checkHmac(signature, legacySig) } diff --git a/pkg/encryption/utils_test.go b/pkg/encryption/utils_test.go index 15bc83fe..162c64ce 100644 --- a/pkg/encryption/utils_test.go +++ b/pkg/encryption/utils_test.go @@ -88,8 +88,10 @@ func TestSignAndValidate(t *testing.T) { value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded")) epoch := "123456789" - sha256sig := cookieSignature(sha256.New, seed, key, value, epoch) - sha1sig := cookieSignature(sha1.New, seed, key, value, epoch) + sha256sig, err := cookieSignature(sha256.New, seed, key, value, epoch) + assert.NoError(t, err) + sha1sig, err := cookieSignature(sha1.New, seed, key, value, epoch) + assert.NoError(t, err) assert.True(t, checkSignature(sha256sig, seed, key, value, epoch)) // This should be switched to False after fully deprecating SHA1 diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 9bfc2b3a..e0ce61dd 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -132,13 +132,19 @@ func (l *Logger) Output(calldepth int, message string) { l.mu.Lock() defer l.mu.Unlock() - l.stdLogTemplate.Execute(l.writer, stdLogMessageData{ + err := l.stdLogTemplate.Execute(l.writer, stdLogMessageData{ Timestamp: FormatTimestamp(now), File: file, Message: message, }) + if err != nil { + panic(err) + } - l.writer.Write([]byte("\n")) + _, err = l.writer.Write([]byte("\n")) + if err != nil { + panic(err) + } } // PrintAuthf writes auth info to the logger. Requires an http.Request to @@ -160,7 +166,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu l.mu.Lock() defer l.mu.Unlock() - l.authTemplate.Execute(l.writer, authLogMessageData{ + err := l.authTemplate.Execute(l.writer, authLogMessageData{ Client: client, Host: req.Host, Protocol: req.Proto, @@ -171,8 +177,14 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu Status: string(status), Message: fmt.Sprintf(format, a...), }) + if err != nil { + panic(err) + } - l.writer.Write([]byte("\n")) + _, err = l.writer.Write([]byte("\n")) + if err != nil { + panic(err) + } } // PrintReq writes request details to the Logger using the http.Request, @@ -208,7 +220,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. l.mu.Lock() defer l.mu.Unlock() - l.reqTemplate.Execute(l.writer, reqLogMessageData{ + err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ Client: client, Host: req.Host, Protocol: req.Proto, @@ -222,8 +234,14 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. UserAgent: fmt.Sprintf("%q", req.UserAgent()), Username: username, }) + if err != nil { + panic(err) + } - l.writer.Write([]byte("\n")) + _, err = l.writer.Write([]byte("\n")) + if err != nil { + panic(err) + } } // GetFileLineString will find the caller file and line number diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index dd4d1405..d299fdf8 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -55,7 +55,7 @@ type storedSessionLoader struct { } // loadSession attempts to load a session as identified by the request cookies. -// If no session is found, the request will be passed to the nex handler. +// If no session is found, the request will be passed to the next handler. // If a session was loader by a previous handler, it will not be replaced. func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -73,7 +73,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { // In the case when there was an error loading the session, // we should clear the session logger.Printf("Error loading cookied session: %v, removing session", err) - s.store.Clear(rw, req) + err = s.store.Clear(rw, req) + if err != nil { + logger.Printf("Error removing session: %v", err) + } } // Add the session to the scope if it was found diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 18717ed1..87d16863 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -44,8 +44,7 @@ func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessi if err != nil { return err } - s.setSessionCookie(rw, req, value, *ss.CreatedAt) - return nil + return s.setSessionCookie(rw, req, value, *ss.CreatedAt) } // Load reads sessions.SessionState information from Cookies within the @@ -114,24 +113,33 @@ func sessionFromCookie(v []byte, c encryption.Cipher) (s *sessions.SessionState, } // setSessionCookie adds the user's session cookie to the response -func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val []byte, created time.Time) { - for _, c := range s.makeSessionCookie(req, val, created) { +func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val []byte, created time.Time) error { + cookies, err := s.makeSessionCookie(req, val, created) + if err != nil { + return err + } + for _, c := range cookies { http.SetCookie(rw, c) } + return nil } // makeSessionCookie creates an http.Cookie containing the authenticated user's // authentication details -func (s *SessionStore) makeSessionCookie(req *http.Request, value []byte, now time.Time) []*http.Cookie { +func (s *SessionStore) makeSessionCookie(req *http.Request, value []byte, now time.Time) ([]*http.Cookie, error) { strValue := string(value) if strValue != "" { - strValue = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, value, now) + var err error + strValue, err = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, value, now) + if err != nil { + return nil, err + } } c := s.makeCookie(req, s.Cookie.Name, strValue, s.Cookie.Expire, now) if len(c.String()) > maxCookieLength { - return splitCookie(c) + return splitCookie(c), nil } - return []*http.Cookie{c} + return []*http.Cookie{c}, nil } func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go index 16a54b4e..4697ad21 100644 --- a/pkg/sessions/persistence/manager.go +++ b/pkg/sessions/persistence/manager.go @@ -48,9 +48,8 @@ func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.Se if err != nil { return err } - tckt.setCookie(rw, req, s) - return nil + return tckt.setCookie(rw, req, s) } // Load reads sessions.SessionState information from a session store. It will diff --git a/pkg/sessions/persistence/ticket.go b/pkg/sessions/persistence/ticket.go index da1097d9..abb78614 100644 --- a/pkg/sessions/persistence/ticket.go +++ b/pkg/sessions/persistence/ticket.go @@ -148,33 +148,42 @@ func (t *ticket) clearSession(clearer clearFunc) error { } // setCookie sets the encoded ticket as a cookie -func (t *ticket) setCookie(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) { - ticketCookie := t.makeCookie( +func (t *ticket) setCookie(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { + ticketCookie, err := t.makeCookie( req, t.encodeTicket(), t.options.Expire, *s.CreatedAt, ) + if err != nil { + return err + } http.SetCookie(rw, ticketCookie) + return nil } // clearCookie removes any cookies that would be where this ticket // would set them func (t *ticket) clearCookie(rw http.ResponseWriter, req *http.Request) { - clearCookie := t.makeCookie( + http.SetCookie(rw, cookies.MakeCookieFromOptions( req, + t.options.Name, "", + t.options, time.Hour*-1, time.Now(), - ) - http.SetCookie(rw, clearCookie) + )) } // makeCookie makes a cookie, signing the value if present -func (t *ticket) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie { +func (t *ticket) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) (*http.Cookie, error) { if value != "" { - value = encryption.SignedValue(t.options.Secret, t.options.Name, []byte(value), now) + var err error + value, err = encryption.SignedValue(t.options.Secret, t.options.Name, []byte(value), now) + if err != nil { + return nil, err + } } return cookies.MakeCookieFromOptions( req, @@ -183,7 +192,7 @@ func (t *ticket) makeCookie(req *http.Request, value string, expires time.Durati t.options, expires, now, - ) + ), nil } // makeCipher makes a AES-GCM cipher out of the ticket's secret diff --git a/pkg/sessions/tests/session_store_tests.go b/pkg/sessions/tests/session_store_tests.go index cf3f8173..f23028f3 100644 --- a/pkg/sessions/tests/session_store_tests.go +++ b/pkg/sessions/tests/session_store_tests.go @@ -311,11 +311,12 @@ func SessionStoreInterfaceTests(in *testInput) { BeforeEach(func() { By("Using a valid cookie with a different providers session encoding") broken := "BrokenSessionFromADifferentSessionImplementation" - value := encryption.SignedValue(in.cookieOpts.Secret, in.cookieOpts.Name, []byte(broken), time.Now()) + value, err := encryption.SignedValue(in.cookieOpts.Secret, in.cookieOpts.Name, []byte(broken), time.Now()) + Expect(err).ToNot(HaveOccurred()) cookie := cookiesapi.MakeCookieFromOptions(in.request, in.cookieOpts.Name, value, in.cookieOpts, in.cookieOpts.Expire, time.Now()) in.request.AddCookie(cookie) - err := in.ss().Save(in.response, in.request, in.session) + err = in.ss().Save(in.response, in.request, in.session) Expect(err).ToNot(HaveOccurred()) }) diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index 5011b775..833b1399 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -103,6 +103,8 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr proxy.FlushInterval = 1 * time.Second } + // InsecureSkipVerify is a configurable option we allow + /* #nosec G402 */ if upstream.InsecureSkipTLSVerify { proxy.Transport = &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -156,6 +158,7 @@ func newWebSocketReverseProxy(u *url.URL, skipTLSVerify bool) http.Handler { wsURL := &url.URL{Scheme: wsScheme, Host: u.Host} wsProxy := wsutil.NewSingleHostReverseProxy(wsURL) + /* #nosec G402 */ if skipTLSVerify { wsProxy.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } diff --git a/pkg/upstream/proxy.go b/pkg/upstream/proxy.go index 6c7c581b..403b9b91 100644 --- a/pkg/upstream/proxy.go +++ b/pkg/upstream/proxy.go @@ -85,6 +85,9 @@ func NewProxyErrorHandler(errorTemplate *template.Template, proxyPrefix string) Message: "Error proxying to upstream server", ProxyPrefix: proxyPrefix, } - errorTemplate.Execute(rw, data) + err := errorTemplate.Execute(rw, data) + if err != nil { + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + } } } diff --git a/pkg/util/util.go b/pkg/util/util.go index e0a3fd3b..4519fdb8 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -12,7 +12,8 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { } pool := x509.NewCertPool() for _, path := range paths { - data, err := ioutil.ReadFile(path) + // Cert paths are a configurable option + data, err := ioutil.ReadFile(path) // #nosec G304 if err != nil { return nil, fmt.Errorf("certificate authority file (%s) could not be read - %s", path, err) } diff --git a/pkg/validation/logging.go b/pkg/validation/logging.go index 2c4754aa..83b965ce 100644 --- a/pkg/validation/logging.go +++ b/pkg/validation/logging.go @@ -13,13 +13,16 @@ func configureLogger(o options.Logging, msgs []string) []string { // Setup the log file if len(o.File.Filename) > 0 { // Validate that the file/dir can be written - file, err := os.OpenFile(o.File.Filename, os.O_WRONLY|os.O_CREATE, 0666) + file, err := os.OpenFile(o.File.Filename, os.O_WRONLY|os.O_CREATE, 0600) if err != nil { if os.IsPermission(err) { return append(msgs, "unable to write to log file: "+o.File.Filename) } } - file.Close() + err = file.Close() + if err != nil { + return append(msgs, "error closing the log file: "+o.File.Filename) + } logger.Printf("Redirecting logging to file: %s", o.File.Filename) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 21601d5a..86197c06 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -30,6 +30,8 @@ func Validate(o *options.Options) error { msgs = append(msgs, validateSessionCookieMinimal(o)...) if o.SSLInsecureSkipVerify { + // InsecureSkipVerify is a configurable option we allow + /* #nosec G402 */ insecureTransport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -217,7 +219,10 @@ func Validate(o *options.Options) error { } if len(o.TrustedIPs) > 0 && o.ReverseProxy { - fmt.Fprintln(os.Stderr, "WARNING: trusting of IPs with --reverse-proxy poses risks if a header spoofing attack is possible.") + _, err := fmt.Fprintln(os.Stderr, "WARNING: trusting of IPs with --reverse-proxy poses risks if a header spoofing attack is possible.") + if err != nil { + panic(err) + } } for i, ipStr := range o.TrustedIPs { diff --git a/providers/logingov.go b/providers/logingov.go index def6043b..c524741f 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -4,12 +4,9 @@ import ( "bytes" "context" "crypto/rsa" - "encoding/json" "errors" "fmt" - "io/ioutil" "math/rand" - "net/http" "net/url" "time" @@ -106,28 +103,12 @@ type loginGovCustomClaims struct { // checkNonce checks the nonce in the id_token func checkNonce(idToken string, p *LoginGovProvider) (err error) { token, err := jwt.ParseWithClaims(idToken, &loginGovCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - resp, myerr := http.Get(p.PubJWKURL.String()) - if myerr != nil { - return nil, myerr - } - if resp.StatusCode != 200 { - myerr = fmt.Errorf("got %d from %q", resp.StatusCode, p.PubJWKURL.String()) - return nil, myerr - } - body, myerr := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if myerr != nil { - return nil, myerr - } - var pubkeys jose.JSONWebKeySet - myerr = json.Unmarshal(body, &pubkeys) - if myerr != nil { - return nil, myerr + rerr := requests.New(p.PubJWKURL.String()).Do().UnmarshalInto(&pubkeys) + if rerr != nil { + return nil, rerr } - pubkey := pubkeys.Keys[0] - - return pubkey.Key, nil + return pubkeys.Keys[0].Key, nil }) if err != nil { return diff --git a/validator.go b/validator.go index e287ef50..9db34a27 100644 --- a/validator.go +++ b/validator.go @@ -3,6 +3,7 @@ package main import ( "encoding/csv" "fmt" + "io" "os" "strings" "sync/atomic" @@ -18,10 +19,12 @@ type UserMap struct { } // NewUserMap parses the authenticated emails file into a new UserMap +// +// TODO (@NickMeves): Audit usage of `unsafe.Pointer` and potentially refactor func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { um := &UserMap{usersFile: usersFile} m := make(map[string]bool) - atomic.StorePointer(&um.m, unsafe.Pointer(&m)) + atomic.StorePointer(&um.m, unsafe.Pointer(&m)) // #nosec G103 if usersFile != "" { logger.Printf("using authenticated emails file %s", usersFile) WatchForUpdates(usersFile, done, func() { @@ -47,7 +50,12 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() { if err != nil { logger.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) } - defer r.Close() + defer func(c io.Closer) { + cerr := c.Close() + if cerr != nil { + logger.Fatalf("Error closing authenticated emails file: %s", cerr) + } + }(r) csvReader := csv.NewReader(r) csvReader.Comma = ',' csvReader.Comment = '#' @@ -62,7 +70,7 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() { address := strings.ToLower(strings.TrimSpace(r[0])) updated[address] = true } - atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) + atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) // #nosec G103 } func newValidatorImpl(domains []string, usersFile string, diff --git a/watcher.go b/watcher.go index 7c916235..b71b8035 100644 --- a/watcher.go +++ b/watcher.go @@ -41,7 +41,12 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { logger.Fatal("failed to create watcher for ", filename, ": ", err) } go func() { - defer watcher.Close() + defer func(w *fsnotify.Watcher) { + cerr := w.Close() + if cerr != nil { + logger.Fatalf("error closing watcher: %v", err) + } + }(watcher) for { select { case <-done: @@ -55,7 +60,10 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { // can't be opened. if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { logger.Printf("watching interrupted on event: %s", event) - watcher.Remove(filename) + err = watcher.Remove(filename) + if err != nil { + logger.Printf("error removing watcher on %s: %v", filename, err) + } WaitForReplacement(filename, event.Op, watcher) } logger.Printf("reloading after event: %s", event)