Refactor SignInMessage out of main
This commit is contained in:
		
							parent
							
								
									46cc21d8cf
								
							
						
					
					
						commit
						e88d29f16a
					
				
							
								
								
									
										15
									
								
								main.go
								
								
								
								
							
							
						
						
									
										15
									
								
								main.go
								
								
								
								
							|  | @ -7,7 +7,6 @@ import ( | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"strings" |  | ||||||
| 	"syscall" | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | @ -62,20 +61,6 @@ func main() { | ||||||
| 		os.Exit(1) | 		os.Exit(1) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(opts.Banner) >= 1 { |  | ||||||
| 		if opts.Banner == "-" { |  | ||||||
| 			oauthproxy.SignInMessage = "" |  | ||||||
| 		} else { |  | ||||||
| 			oauthproxy.SignInMessage = opts.Banner |  | ||||||
| 		} |  | ||||||
| 	} else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { |  | ||||||
| 		if len(opts.EmailDomains) > 1 { |  | ||||||
| 			oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) |  | ||||||
| 		} else if opts.EmailDomains[0] != "*" { |  | ||||||
| 			oauthproxy.SignInMessage = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| 
 | 
 | ||||||
| 	chain := alice.New() | 	chain := alice.New() | ||||||
|  |  | ||||||
|  | @ -213,6 +213,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		trustedIPs:              trustedIPs, | 		trustedIPs:              trustedIPs, | ||||||
| 		Banner:                  opts.Banner, | 		Banner:                  opts.Banner, | ||||||
| 		Footer:                  opts.Footer, | 		Footer:                  opts.Footer, | ||||||
|  | 		SignInMessage:           buildSignInMessage(opts), | ||||||
| 
 | 
 | ||||||
| 		basicAuthValidator:  basicAuthValidator, | 		basicAuthValidator:  basicAuthValidator, | ||||||
| 		displayHtpasswdForm: basicAuthValidator != nil, | 		displayHtpasswdForm: basicAuthValidator != nil, | ||||||
|  | @ -255,6 +256,24 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | ||||||
| 	return chain | 	return chain | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func buildSignInMessage(opts *options.Options) string { | ||||||
|  | 	var msg string | ||||||
|  | 	if len(opts.Banner) >= 1 { | ||||||
|  | 		if opts.Banner == "-" { | ||||||
|  | 			msg = "" | ||||||
|  | 		} else { | ||||||
|  | 			msg = opts.Banner | ||||||
|  | 		} | ||||||
|  | 	} else if len(opts.EmailDomains) != 0 && opts.AuthenticatedEmailsFile == "" { | ||||||
|  | 		if len(opts.EmailDomains) > 1 { | ||||||
|  | 			msg = fmt.Sprintf("Authenticate using one of the following domains: %v", strings.Join(opts.EmailDomains, ", ")) | ||||||
|  | 		} else if opts.EmailDomains[0] != "*" { | ||||||
|  | 			msg = fmt.Sprintf("Authenticate using %v", opts.EmailDomains[0]) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return msg | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | ||||||
| // redirect clients to once authenticated
 | // redirect clients to once authenticated
 | ||||||
| func (p *OAuthProxy) GetRedirectURI(host string) string { | func (p *OAuthProxy) GetRedirectURI(host string) string { | ||||||
|  |  | ||||||
|  | @ -44,7 +44,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { | ||||||
| 		defer func(w *fsnotify.Watcher) { | 		defer func(w *fsnotify.Watcher) { | ||||||
| 			cerr := w.Close() | 			cerr := w.Close() | ||||||
| 			if cerr != nil { | 			if cerr != nil { | ||||||
| 				logger.Fatalf("error closing watcher: %s", err) | 				logger.Fatalf("error closing watcher: %v", err) | ||||||
| 			} | 			} | ||||||
| 		}(watcher) | 		}(watcher) | ||||||
| 		for { | 		for { | ||||||
|  | @ -62,7 +62,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { | ||||||
| 					logger.Printf("watching interrupted on event: %s", event) | 					logger.Printf("watching interrupted on event: %s", event) | ||||||
| 					err = watcher.Remove(filename) | 					err = watcher.Remove(filename) | ||||||
| 					if err != nil { | 					if err != nil { | ||||||
| 						logger.Printf("error removing watcher on %s: %s", filename, err) | 						logger.Printf("error removing watcher on %s: %v", filename, err) | ||||||
| 					} | 					} | ||||||
| 					WaitForReplacement(filename, event.Op, watcher) | 					WaitForReplacement(filename, event.Op, watcher) | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue