diff --git a/main.go b/main.go index a75a303b..ef9ac44e 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "os" "os/signal" "runtime" - "strings" "syscall" "time" @@ -62,20 +61,6 @@ func main() { os.Exit(1) } - if len(opts.Banner) >= 1 { - if opts.Banner == "-" { - oauthproxy.SignInMessage = "" - } else { - oauthproxy.SignInMessage = opts.Banner - } - } else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { - if len(opts.EmailDomains) > 1 { - oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) - } else if opts.EmailDomains[0] != "*" { - oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) - } - } - rand.Seed(time.Now().UnixNano()) chain := alice.New() diff --git a/oauthproxy.go b/oauthproxy.go index c7d1c1df..ac9e0565 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -213,6 +213,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr trustedIPs: trustedIPs, Banner: opts.Banner, Footer: opts.Footer, + SignInMessage: buildSignInMessage(opts), basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil, @@ -255,6 +256,24 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt return chain } +func buildSignInMessage(opts *options.Options) string { + var msg string + if len(opts.Banner) >= 1 { + if opts.Banner == "-" { + msg = "" + } else { + msg = opts.Banner + } + } else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { + if len(opts.EmailDomains) > 1 { + msg = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) + } else if opts.EmailDomains[0] != "*" { + msg = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) + } + } + return msg +} + // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will // redirect clients to once authenticated func (p *OAuthProxy) GetRedirectURI(host string) string { diff --git a/watcher.go b/watcher.go index d5e02a02..b71b8035 100644 --- a/watcher.go +++ b/watcher.go @@ -44,7 +44,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { defer func(w *fsnotify.Watcher) { cerr := w.Close() if cerr != nil { - logger.Fatalf("error closing watcher: %s", err) + logger.Fatalf("error closing watcher: %v", err) } }(watcher) for { @@ -62,7 +62,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { logger.Printf("watching interrupted on event: %s", event) err = watcher.Remove(filename) if err != nil { - logger.Printf("error removing watcher on %s: %s", filename, err) + logger.Printf("error removing watcher on %s: %v", filename, err) } WaitForReplacement(filename, event.Op, watcher) }