diff --git a/CHANGELOG.md b/CHANGELOG.md index 788e82c2..301efe04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,13 @@ ## Release Highlights +- Logging now uses Go's `log/slog`, adding configurable log levels and optional structured JSON output while keeping text output as the default. + ## Important Notes +- The new `--logging-level` flag controls log verbosity with `debug`, `info`, `warn`, and `error` levels. +- The new `--logging-format` flag supports `text` and `json`; `text` remains the default, and existing text template options continue to apply to text output. + ## Breaking Changes ## Changes since v7.15.3 diff --git a/contrib/oauth2-proxy.cfg.example b/contrib/oauth2-proxy.cfg.example index df16e4cc..1365cf62 100644 --- a/contrib/oauth2-proxy.cfg.example +++ b/contrib/oauth2-proxy.cfg.example @@ -22,6 +22,8 @@ # ] ## Logging configuration +#logging_level = "info" +#logging_format = "text" #logging_filename = "" #logging_max_size = 100 #logging_max_age = 7 @@ -29,10 +31,10 @@ #logging_compress = false #standard_logging = true #standard_logging_format = "[{{.Timestamp}}] [{{.File}}] {{.Message}}" -#request_logging = true -#request_logging_format = "{{.Client}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" #auth_logging = true -#auth_logging_format = "{{.Client}} - {{.Username}} [{{.Timestamp}}] [{{.Status}}] {{.Message}}" +#auth_logging_format = "{{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] [{{.Status}}] {{.Message}}" +#request_logging = true +#request_logging_format = "{{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" ## pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream # pass_basic_auth = true diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index 965953fa..f62b0764 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -154,22 +154,24 @@ Provider specific options can be found on their respective subpages. | Flag / Config Field | Type | Description | Default | | --------------------------------------------------------------------- | ------ | ---------------------------------------------------------------------------- | --------------------------------------------------- | -| flag: `--auth-logging-format`
toml: `auth_logging_format` | string | Template for authentication log lines | see [Logging Configuration](#logging-configuration) | +| flag: `--logging-level`
toml: `logging_level` | string | Log level: `debug`, `info`, `warn`, `error` | `"info"` | +| flag: `--logging-format`
toml: `logging_format` | string | Log format: `text`, `json` | `"text"` | +| flag: `--auth-logging-format`
toml: `auth_logging_format` | string | Template for authentication log lines when `logging_format` is `text` | see [Logging Configuration](#logging-configuration) | | flag: `--auth-logging`
toml: `auth_logging` | bool | Log authentication attempts | true | -| flag: `--errors-to-info-log`
toml: `errors_to_info_log` | bool | redirects error-level logging to default log channel instead of stderr | false | -| flag: `--exclude-logging-path`
toml: `exclude_logging_paths` | string | comma separated list of paths to exclude from logging, e.g. `"/ping,/path2"` | `""` (no paths excluded) | -| flag: `--logging-compress`
toml: `logging_compress` | bool | Should rotated log files be compressed using gzip | false | +| flag: `--request-logging-format`
toml: `request_logging_format` | string | Template for request log lines when `logging_format` is `text` | see [Logging Configuration](#logging-configuration) | +| flag: `--request-logging`
toml: `request_logging` | bool | Log HTTP requests | true | +| flag: `--standard-logging-format`
toml: `standard_logging_format` | string | Template for standard log lines when `logging_format` is `text` | see [Logging Configuration](#logging-configuration) | +| flag: `--standard-logging`
toml: `standard_logging` | bool | Log standard runtime information | true | +| flag: `--errors-to-info-log`
toml: `errors_to_info_log` | bool | Redirect error-level logging to the standard log channel instead of stderr | false | +| flag: `--exclude-logging-path`
toml: `exclude_logging_paths` | string | Comma separated list of paths to exclude from logging, e.g. `"/ping,/path2"` | `""` (no paths excluded) | +| flag: `--silence-ping-logging`
toml: `silence_ping_logging` | bool | Disable logging of requests to ping & ready endpoints | false | +| flag: `--request-id-header`
toml: `request_id_header` | string | Request header to use as the request ID in logging | `"X-Request-Id"` | | flag: `--logging-filename`
toml: `logging_filename` | string | File to log requests to, empty for `stdout` | `""` (stdout) | -| flag: `--logging-local-time`
toml: `logging_local_time` | bool | Use local time in log files and backup filenames instead of UTC | true (local time) | +| flag: `--logging-max-size`
toml: `logging_max_size` | int | Maximum size in megabytes of the log file before rotation | 100 | | flag: `--logging-max-age`
toml: `logging_max_age` | int | Maximum number of days to retain old log files | 7 | | flag: `--logging-max-backups`
toml: `logging_max_backups` | int | Maximum number of old log files to retain; 0 to disable | 0 | -| flag: `--logging-max-size`
toml: `logging_max_size` | int | Maximum size in megabytes of the log file before rotation | 100 | -| flag: `--request-id-header`
toml: `request_id_header` | string | Request header to use as the request ID in logging | X-Request-Id | -| flag: `--request-logging-format`
toml: `request_logging_format` | string | Template for request log lines | see [Logging Configuration](#logging-configuration) | -| flag: `--request-logging`
toml: `request_logging` | bool | Log requests | true | -| flag: `--silence-ping-logging`
toml: `silence_ping_logging` | bool | disable logging of requests to ping & ready endpoints | false | -| flag: `--standard-logging-format`
toml: `standard_logging_format` | string | Template for standard log lines | see [Logging Configuration](#logging-configuration) | -| flag: `--standard-logging`
toml: `standard_logging` | bool | Log standard runtime information | true | +| flag: `--logging-local-time`
toml: `logging_local_time` | bool | Use local time in log files and backup filenames instead of UTC | true (local time) | +| flag: `--logging-compress`
toml: `logging_compress` | bool | Should rotated log files be compressed using gzip | false | ### Page Template Options @@ -349,19 +351,47 @@ Please check the type for each [config option](#config-options) first. ## Logging Configuration -By default, OAuth2 Proxy logs all output to stdout. Logging can be configured to output to a rotating log file using the `--logging-filename` command. +OAuth2 Proxy uses Go's standard `log/slog` package for log levels and structured logging support. The default output format is `text`, preserving the existing template-based log lines. Set `--logging-format=json` to emit structured JSON logs for log aggregation systems. + +By default, OAuth2 Proxy logs standard informational output to stdout and standard warning/error output to stderr. Logging can be configured to output to a rotating log file using the `--logging-filename` command. If logging to a file you can also configure the maximum file size (`--logging-max-size`), age (`--logging-max-age`), max backup logs (`--logging-max-backups`), and if backup logs should be compressed (`--logging-compress`). There are three different types of logging: standard, authentication, and HTTP requests. These can each be enabled or disabled with `--standard-logging`, `--auth-logging`, and `--request-logging`. -Each type of logging has its own configurable format and variables. By default, these formats are similar to the Apache Combined Log. +### Log Format -Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) and the `/ready` endpoint can be disabled with `--silence-ping-logging` reducing log volume. +Two log formats are supported via `--logging-format`: -## Auth Log Format +- **`text`** (default): Human-readable log lines using the standard, authentication, and request templates documented below. +- **`json`**: Machine-readable JSON, one object per line. The template format options are ignored in JSON mode because fields are emitted as structured attributes. -Authentication logs are logs which are guaranteed to contain a username or email address of a user attempting to authenticate. These logs are output by default in the below format: +Example text output: + +``` +127.0.0.1:59742 - 6cccb6ca - user@example.com [2025/01/15 10:30:00] app.example.com GET - "/oauth2/callback" HTTP/1.1 "Mozilla/5.0" 200 12 0.001 +``` + +Example JSON output: + +```json +{"time":"2025-01-15T10:30:00Z","level":"INFO","source":{"function":"main.main","file":"main.go","line":42},"msg":"request","user":"user@example.com","client":"127.0.0.1:59742","host":"app.example.com","method":"GET","uri":"/oauth2/callback","protocol":"HTTP/1.1","status_code":200,"response_size":12,"duration_s":"0.001","request_id":"6cccb6ca"} +``` + +### Log Levels + +Four levels are supported via `--logging-level`: + +| Level | Description | +| ------- | ------------------------------------------------------------------------ | +| `debug` | Verbose output for troubleshooting | +| `info` | Normal operational messages (default) | +| `warn` | Warning conditions, including explicit authentication failures | +| `error` | Error conditions written to stderr unless `--errors-to-info-log` is set | + +### Auth Log Format + +Authentication logs are logs which are guaranteed to contain a username or email address of a user attempting to authenticate. In `text` format, these logs are output by default in the below format: ``` - - [2015/03/19 17:20:19] [] @@ -373,8 +403,7 @@ The status block will contain one of the below strings: - `AuthFailure` If the user failed to authenticate explicitly - `AuthError` If there was an unexpected error during authentication -If you require a different format than that, you can configure it with the `--auth-logging-format` flag. -The default format is configured as follows: +If you require a different text format, you can configure it with the `--auth-logging-format` flag. The default format is configured as follows: ``` {{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] [{{.Status}}] {{.Message}} @@ -395,16 +424,15 @@ Available variables for auth logging: | Username | username@email.com | The email or username of the auth request. | | Status | AuthSuccess | The status of the auth request. See above for details. | -## Request Log Format +### Request Log Format -HTTP request logs will output by default in the below format: +HTTP request logs will output by default in the below text format: ``` - - [2015/03/19 17:20:19] GET "/path/" HTTP/1.1 "" ``` -If you require a different format than that, you can configure it with the `--request-logging-format` flag. -The default format is configured as follows: +If you require a different text format, you can configure it with the `--request-logging-format` flag. The default format is configured as follows: ``` {{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}} @@ -428,15 +456,15 @@ Available variables for request logging: | UserAgent | - | The full user agent as reported by the requesting client. | | Username | username@email.com | The email or username of the auth request. | -## Standard Log Format +### Standard Log Format -All other logging that is not covered by the above two types of logging will be output in this standard logging format. This includes configuration information at startup and errors that occur outside of a session. The default format is below: +All other logging that is not covered by the above two types of logging will be output in this standard logging format. This includes configuration information at startup and errors that occur outside of a session. The default text format is below: ``` [2015/03/19 17:20:19] [main.go:40] ``` -If you require a different format than that, you can configure it with the `--standard-logging-format` flag. The default format is configured as follows: +If you require a different text format, you can configure it with the `--standard-logging-format` flag. The default format is configured as follows: ``` [{{.Timestamp}}] [{{.File}}] {{.Message}} @@ -449,3 +477,47 @@ Available variables for standard logging: | Timestamp | 2015/03/19 17:20:19 | The date and time of the logging event. | | File | main.go:40 | The file and line number of the logging statement. | | Message | HTTP: listening on 127.0.0.1:4180 | The details of the log statement. | + +### JSON Fields + +When `--logging-format=json` is configured, log entries automatically include structured fields relevant to the log type. + +Authentication logs include: + +| Field | Description | +| ------------ | -------------------------------------------------------- | +| `user` | The email or username of the auth request | +| `status` | `AuthSuccess`, `AuthFailure`, or `AuthError` | +| `client` | The client/remote IP address | +| `request_id` | The request ID from `--request-id-header`, if available | +| `host` | The Host header value | +| `method` | The HTTP request method | +| `protocol` | The request protocol | +| `user_agent` | The client User-Agent header | + +Request logs include: + +| Field | Description | +| --------------- | ------------------------------------------------------- | +| `user` | The email or username | +| `upstream` | The upstream backend that handled the request | +| `client` | The client/remote IP address | +| `request_id` | The request ID from `--request-id-header`, if available | +| `host` | The Host header value | +| `method` | The HTTP request method | +| `uri` | The request URI path | +| `protocol` | The request protocol | +| `user_agent` | The client User-Agent header | +| `status_code` | The HTTP response status code | +| `response_size` | The response size in bytes | +| `duration_s` | The request duration in seconds | + +### Log Routing + +By default, standard log messages at `INFO` and below are written to stdout, while `WARN` and `ERROR` messages are written to stderr. Use `--errors-to-info-log` to redirect error output to stdout. + +When `--logging-filename` is configured, logs are written to the specified file with automatic rotation support via `--logging-max-size`, `--logging-max-age`, `--logging-max-backups`, and `--logging-compress`. + +### Filtering + +Logging of requests to the `/ping` and `/ready` endpoints can be disabled with `--silence-ping-logging` to reduce log volume. Additional paths can be excluded with `--exclude-logging-path`. diff --git a/main.go b/main.go index ba970679..94819dc5 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,6 @@ import ( ) func main() { - logger.SetFlags(logger.Lshortfile) configFlagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ContinueOnError) @@ -35,21 +34,21 @@ func main() { } if *convertConfig && *alphaConfig != "" { - logger.Fatal("cannot use alpha-config and convert-config-to-alpha together") + logger.FatalMsg("cannot use alpha-config and convert-config-to-alpha together") } if *configTest && *convertConfig { - logger.Fatal("cannot use config-test and convert-config-to-alpha together") + logger.FatalMsg("cannot use config-test and convert-config-to-alpha together") } opts, err := loadConfiguration(*config, *alphaConfig, configFlagSet, os.Args[1:]) if err != nil { - logger.Fatalf("ERROR: %v", err) + logger.FatalMsg("failed to load configuration", "error", err) } if *configTest { if err = validation.Validate(opts); err != nil { - logger.Errorf("%s", err) + logger.ErrMsgf("%s", err) os.Exit(1) } fmt.Println("configuration is valid") @@ -58,23 +57,23 @@ func main() { if *convertConfig { if err := printConvertedConfig(opts); err != nil { - logger.Fatalf("ERROR: could not convert config: %v", err) + logger.FatalMsg("could not convert config", "error", err) } return } if err = validation.Validate(opts); err != nil { - logger.Fatalf("%s", err) + logger.FatalMsg("invalid configuration", "error", err) } validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) oauthproxy, err := NewOAuthProxy(opts, validator) if err != nil { - logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) + logger.FatalMsg("failed to initialise OAuth2 Proxy", "error", err) } if err := oauthproxy.Start(); err != nil { - logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) + logger.FatalMsg("failed to start OAuth2 Proxy", "error", err) } } @@ -88,7 +87,7 @@ func loadConfiguration(config, yamlConfig string, extraFlags *pflag.FlagSet, arg } if yamlConfig != "" { - logger.Printf("WARNING: You are using alpha configuration. The structure in this configuration file may change without notice. You MUST remove conflicting options from your existing configuration.") + logger.Warn("alpha configuration in use: the structure in this configuration file may change without notice, remove conflicting options from your existing configuration") opts, err = loadYamlOptions(yamlConfig, config, extraFlags, args) if err != nil { return nil, fmt.Errorf("failed to load yaml options: %w", err) diff --git a/main_suite_test.go b/main_suite_test.go index 360a0ee2..eb27b106 100644 --- a/main_suite_test.go +++ b/main_suite_test.go @@ -1,6 +1,7 @@ package main import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestMainSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Main Suite") diff --git a/oauthproxy.go b/oauthproxy.go index f8dc5471..5e1d4c5a 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -129,7 +129,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr var basicAuthValidator basic.Validator if opts.HtpasswdFile != "" { - logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) + logger.Info("using htpasswd file", "path", opts.HtpasswdFile) var err error basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) if err != nil { @@ -163,12 +163,12 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } if opts.SkipJwtBearerTokens { - logger.Printf("Skipping JWT tokens from configured OIDC issuer: %q", opts.Providers[0].OIDCConfig.IssuerURL) + logger.Info("skipping JWT tokens from configured OIDC issuer", "issuer", opts.Providers[0].OIDCConfig.IssuerURL) for _, issuer := range opts.ExtraJwtIssuers { - logger.Printf("Skipping JWT tokens from extra JWT issuer: %q", issuer) + logger.Info("skipping JWT tokens from extra JWT issuer", "issuer", issuer) } if !opts.BearerTokenLoginFallback { - logger.Println("Denying requests with invalid JWT tokens") + logger.Info("denying requests with invalid JWT tokens") } } @@ -177,13 +177,22 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) } - logger.Printf("OAuthProxy configured for %s Client ID: %s", provider.Data().ProviderName, opts.Providers[0].ClientID) + logger.Info("OAuthProxy configured", "provider", provider.Data().ProviderName, "client_id", opts.Providers[0].ClientID) refresh := "disabled" if opts.Cookie.Refresh != time.Duration(0) { refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh) } - logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domains:%s path:%s samesite:%s refresh:%s", opts.Cookie.Name, opts.Cookie.Secure, opts.Cookie.HTTPOnly, opts.Cookie.Expire, strings.Join(opts.Cookie.Domains, ","), opts.Cookie.Path, opts.Cookie.SameSite, refresh) + logger.Info("cookie settings", + "name", opts.Cookie.Name, + "secure", opts.Cookie.Secure, + "httponly", opts.Cookie.HTTPOnly, + "expiry", opts.Cookie.Expire, + "domains", strings.Join(opts.Cookie.Domains, ","), + "path", opts.Cookie.Path, + "samesite", opts.Cookie.SameSite, + "refresh", refresh, + ) trustedIPs, err := ip.ParseNetSet(opts.TrustedIPs) if err != nil { @@ -372,7 +381,7 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt healthCheckPaths := []string{opts.PingPath} healthCheckUserAgents := []string{opts.PingUserAgent} if opts.GCPHealthChecks { - logger.Printf("WARNING: GCP HealthChecks are now deprecated: Reconfigure apps to use the ping path for liveness and readiness checks, set the ping user agent to \"GoogleHC/1.0\" to preserve existing behaviour") + logger.Warn("GCP HealthChecks are now deprecated: reconfigure apps to use the ping path for liveness and readiness checks, set the ping user agent to GoogleHC/1.0 to preserve existing behaviour") healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check") healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0") } @@ -401,7 +410,7 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt func buildTrustedProxyNetSet(opts *options.Options) (*ip.NetSet, error) { trustedProxyIPs := opts.TrustedProxyIPs if opts.ReverseProxy && len(trustedProxyIPs) == 0 { - logger.Print("WARNING: --reverse-proxy is enabled but no --trusted-proxy-ip CIDRs were configured. All connecting IPs are trusted to supply X-Forwarded-* headers by default (0.0.0.0/0, ::/0). This preserves backwards compatibility but is a potential security risk; configure --trusted-proxy-ip to match your reverse proxy addresses.") + logger.Warn("--reverse-proxy is enabled but no --trusted-proxy-ip CIDRs were configured. All connecting IPs are trusted to supply X-Forwarded-* headers by default (0.0.0.0/0, ::/0). This preserves backwards compatibility but is a potential security risk; configure --trusted-proxy-ip to match your reverse proxy addresses.") trustedProxyIPs = defaultTrustedProxyIPs } @@ -489,7 +498,7 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { if err != nil { return nil, err } - logger.Printf("Skipping auth - Method: ALL | Path: %s", path) + logger.Info("skipping auth", "method", "ALL", "path", path) routes = append(routes, allowedRoute{ method: "", pathRegex: compiledRegex, @@ -516,7 +525,7 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { if err != nil { return nil, err } - logger.Printf("Skipping auth - Method: %s | Path: %s", method, path) + logger.Info("skipping auth", "method", method, "path", path) routes = append(routes, allowedRoute{ method: method, negate: negate, @@ -536,7 +545,7 @@ func buildAPIRoutes(opts *options.Options) ([]apiRoute, error) { if err != nil { return nil, err } - logger.Printf("API route - Path: %s", path) + logger.Info("API route", "path", path) routes = append(routes, apiRoute{ pathRegex: compiledRegex, }) @@ -569,7 +578,7 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, appError string, messages ...interface{}) { redirectURL, err := p.appDirector.GetRedirect(req) if err != nil { - logger.Errorf("Error obtaining redirect: %v", err) + logger.ErrMsgf("error obtaining redirect: %v", err) } if redirectURL == p.SignInPath || redirectURL == "" { redirectURL = "/" @@ -632,7 +641,7 @@ func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) if err != nil { - logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) + logger.ErrMsgf("error obtaining real IP for trusted IP list: %v", err) // Possibly spoofed X-Real-IP header return false } @@ -649,13 +658,13 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code prepareNoCache(rw) if err := p.ClearSessionCookie(rw, req); err != nil { - logger.Printf("Error clearing session cookie: %v", err) + logger.ErrMsgf("error clearing session cookie: %v", err) } rw.WriteHeader(code) redirectURL, err := p.appDirector.GetRedirect(req) if err != nil { - logger.Errorf("Error obtaining redirect: %v", err) + logger.ErrMsgf("error obtaining redirect: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -679,10 +688,10 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool, int) { } // check auth if p.basicAuthValidator.Validate(user, passwd) { - logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via HtpasswdFile") + logger.LogAuth(user, req, logger.AuthSuccess, "authenticated via HtpasswdFile") return user, true, http.StatusOK } - logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via HtpasswdFile") + logger.LogAuth(user, req, logger.AuthFailure, "invalid authentication via HtpasswdFile") return "", false, http.StatusUnauthorized } @@ -690,7 +699,7 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool, int) { func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { redirect, err := p.appDirector.GetRedirect(req) if err != nil { - logger.Errorf("Error obtaining redirect: %v", err) + logger.ErrMsgf("error obtaining redirect: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -700,7 +709,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { session := &sessionsapi.SessionState{User: user, Groups: p.basicAuthGroups} err = p.SaveSession(rw, req, session) if err != nil { - logger.Printf("Error saving session: %v", err) + logger.ErrMsgf("error saving session: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -727,7 +736,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) if session == nil { if _, err := rw.Write([]byte("{}")); err != nil { - logger.Printf("Error encoding empty user info: %v", err) + logger.ErrMsgf("error encoding empty user info: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) } return @@ -748,7 +757,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { } if err := json.NewEncoder(rw).Encode(userInfo); err != nil { - logger.Printf("Error encoding user info: %v", err) + logger.ErrMsgf("error encoding user info: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) } } @@ -757,7 +766,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { redirect, err := p.appDirector.GetRedirect(req) if err != nil { - logger.Errorf("Error obtaining redirect: %v", err) + logger.ErrMsgf("error obtaining redirect: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -765,7 +774,7 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { if strings.Contains(redirect, idTokenPlaceholder) { session, err := p.getAuthenticatedSession(rw, req) if err != nil { - logger.Errorf("error getting authenticated session during SignOut, won't replace id_token placeholder in redirect URL: %v", err) + logger.ErrMsgf("error getting authenticated session during SignOut, won't replace id_token placeholder in redirect URL: %v", err) } else { redirect = strings.ReplaceAll(redirect, idTokenPlaceholder, session.IDToken) } @@ -777,7 +786,7 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { err = p.ClearSessionCookie(rw, req) if err != nil { - logger.Errorf("Error clearing session cookie: %v", err) + logger.ErrMsgf("error clearing session cookie: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -788,7 +797,7 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) { session, err := p.getAuthenticatedSession(rw, req) if err != nil { - logger.Errorf("error getting authenticated session during backend logout: %v", err) + logger.ErrMsgf("error getting authenticated session during backend logout: %v", err) return } @@ -806,13 +815,13 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) { // base is not end-user provided but comes from configuration somewhat secure resp, err := http.Get(backendLogoutURL) // #nosec G107 if err != nil { - logger.Errorf("error while calling backend logout: %v", err) + logger.ErrMsgf("error while calling backend logout: %v", err) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { - logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode) + logger.ErrMsgf("error while calling backend logout url, returned error code %v", resp.StatusCode) } } @@ -834,14 +843,14 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove codeChallengeMethod = p.provider.Data().CodeChallengeMethod codeVerifier, err = encryption.GenerateCodeVerifierString(96) if err != nil { - logger.Errorf("Unable to build random ASCII string for code verifier: %v", err) + logger.ErrMsgf("unable to build random ASCII string for code verifier: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } codeChallenge, err = encryption.GenerateCodeChallenge(p.provider.Data().CodeChallengeMethod, codeVerifier) if err != nil { - logger.Errorf("Error creating code challenge: %v", err) + logger.ErrMsgf("error creating code challenge: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -852,14 +861,14 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove csrf, err := cookies.NewCSRF(p.CookieOptions, codeVerifier) if err != nil { - logger.Errorf("Error creating CSRF nonce: %v", err) + logger.ErrMsgf("error creating CSRF nonce: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } appRedirect, err := p.appDirector.GetRedirect(req) if err != nil { - logger.Errorf("Error obtaining application redirect: %v", err) + logger.ErrMsgf("error obtaining application redirect: %v", err) p.ErrorPage(rw, req, http.StatusBadRequest, err.Error()) return } @@ -873,7 +882,7 @@ func (p *OAuthProxy) doOAuthStart(rw http.ResponseWriter, req *http.Request, ove ) cookies.ClearExtraCsrfCookies(p.CookieOptions, rw, req) if _, err := csrf.SetCookie(rw, req); err != nil { - logger.Errorf("Error setting CSRF cookie: %v", err) + logger.ErrMsgf("error setting CSRF cookie: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -890,13 +899,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { // unlikely to be hit in practice. err := req.ParseForm() if err != nil { - logger.Errorf("Error while parsing OAuth2 callback: %v", err) + logger.ErrMsgf("error while parsing OAuth2 callback: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } errorString := req.Form.Get("error") if errorString != "" { - logger.Errorf("Error while parsing OAuth2 callback: %s", errorString) + logger.ErrMsg("error in OAuth2 callback from upstream identity provider", "error", errorString) message := fmt.Sprintf("Login Failed: The upstream identity provider returned an error: %s", errorString) // Set the debug message and override the non debug message to be the same for this case p.ErrorPage(rw, req, http.StatusForbidden, message, message) @@ -905,7 +914,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { nonce, appRedirect, err := decodeState(req.Form.Get("state"), p.encodeState) if err != nil { - logger.Errorf("Error while parsing OAuth2 state: %v", err) + logger.ErrMsgf("error while parsing OAuth2 state: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -918,21 +927,21 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { // There are a lot of issues opened complaining about missing CSRF cookies. // Try to log the INs and OUTs of OAuthProxy, to be easier to analyse these issues. LoggingCSRFCookiesInOAuthCallback(req, cookieName) - logger.Println(req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie: %s (state=%s)", err, nonce) + logger.Warn("invalid authentication via OAuth2: unable to obtain CSRF cookie", "error", err, "state", nonce) p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.") return } session, err := p.redeemCode(req, csrf.GetCodeVerifier()) if err != nil { - logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) + logger.ErrMsgf("error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } err = p.enrichSessionState(req.Context(), session) if err != nil { - logger.Errorf("Error creating session during OAuth2 callback: %v", err) + logger.ErrMsgf("error creating session during OAuth2 callback: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } @@ -940,14 +949,14 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { csrf.ClearCookie(rw, req) if !csrf.CheckOAuthState(nonce) { - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack") + logger.LogAuth(session.Email, req, logger.AuthFailure, "invalid authentication via OAuth2: CSRF token mismatch, potential attack") p.ErrorPage(rw, req, http.StatusForbidden, "CSRF token mismatch, potential attack", "Login Failed: Unable to find a valid CSRF token. Please try again.") return } csrf.SetSessionNonce(session) if !p.provider.ValidateSession(req.Context(), session) { - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Session validation failed: %s", session) + logger.LogAuth(session.Email, req, logger.AuthFailure, fmt.Sprintf("session validation failed: %s", session)) p.ErrorPage(rw, req, http.StatusForbidden, "Session validation failed") return } @@ -959,19 +968,19 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { // set cookie, or deny authorized, err := p.provider.Authorize(req.Context(), session) if err != nil { - logger.Errorf("Error with authorization: %v", err) + logger.ErrMsgf("error with authorization: %v", err) } if p.Validator(session.Email) && authorized { - logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) + logger.LogAuth(session.Email, req, logger.AuthSuccess, fmt.Sprintf("authenticated via OAuth2: %s", session)) err := p.SaveSession(rw, req, session) if err != nil { - logger.Errorf("Error saving session state for %s: %v", remoteAddr, err) + logger.ErrMsg("error saving session state", "remote_addr", remoteAddr, "error", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) return } http.Redirect(rw, req, appRedirect, http.StatusFound) } else { - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized") + logger.LogAuth(session.Email, req, logger.AuthFailure, "invalid authentication via OAuth2: unauthorized") p.ErrorPage(rw, req, http.StatusForbidden, "Invalid session: unauthorized") } } @@ -1055,13 +1064,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { case ErrNeedsLogin: // we need to send the user to a login screen if p.forceJSONErrors || isAjax(req) || p.isAPIPath(req) { - logger.Printf("No valid authentication in request. Access Denied.") + logger.Warn("no valid authentication in request, access denied") // no point redirecting an AJAX request p.errorJSON(rw, http.StatusUnauthorized) return } - logger.Printf("No valid authentication in request. Initiating login.") + logger.Info("no valid authentication in request, initiating login") if p.SkipProviderButton { // start OAuth flow, but only with the default login URL params - do not // consider this request's query params as potential overrides, since @@ -1080,7 +1089,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { default: // unknown error - logger.Errorf("Unexpected internal error: %v", err) + logger.ErrMsgf("unexpected internal error: %v", err) p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) } } @@ -1154,7 +1163,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R invalidEmail := session.Email != "" && !p.Validator(session.Email) authorized, err := p.provider.Authorize(req.Context(), session) if err != nil { - logger.Errorf("Error with authorization: %v", err) + logger.ErrMsgf("error with authorization: %v", err) } if invalidEmail || !authorized { @@ -1163,11 +1172,11 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R cause = "invalid email" } - logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authorization via session (%s): removing session %s", cause, session) + logger.LogAuth(session.Email, req, logger.AuthFailure, fmt.Sprintf("invalid authorization via session (%s): removing session %s", cause, session)) // Invalid session, clear it err := p.ClearSessionCookie(rw, req) if err != nil { - logger.Errorf("Error clearing session cookie: %v", err) + logger.ErrMsgf("error clearing session cookie: %v", err) } return nil, ErrAccessDenied } @@ -1349,21 +1358,21 @@ func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) { func LoggingCSRFCookiesInOAuthCallback(req *http.Request, cookieName string) { cookies := req.Cookies() if len(cookies) == 0 { - logger.Println(req, logger.AuthFailure, "No cookies were found in OAuth callback.") + logger.Warn("no cookies were found in OAuth callback") return } for _, c := range cookies { if cookieName == c.Name { - logger.Println(req, logger.AuthFailure, "CSRF cookie %s was found in OAuth callback.", c.Name) + logger.Warn("CSRF cookie found in OAuth callback", "cookie_name", c.Name) return } if strings.HasSuffix(c.Name, "_csrf") { - logger.Println(req, logger.AuthFailure, "CSRF cookie %s was found in OAuth callback, but it is not the expected one (%s).", c.Name, cookieName) + logger.Warn("unexpected CSRF cookie found in OAuth callback", "cookie_name", c.Name, "expected", cookieName) return } } - logger.Println(req, logger.AuthFailure, "Cookies were found in OAuth callback, but none was a CSRF cookie.") + logger.Warn("cookies were found in OAuth callback, but none was a CSRF cookie") } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index b3271e5b..4da2c749 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -40,10 +40,6 @@ const ( clientSecret = "gv3498mfc9t23y23974dm2394dm9" ) -func init() { - logger.SetFlags(logger.Lshortfile) -} - func TestRobotsTxt(t *testing.T) { opts := baseTestOptions() err := validation.Validate(opts) @@ -178,7 +174,7 @@ func Test_enrichSession(t *testing.T) { func TestBasicAuthPassword(t *testing.T) { providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Printf("%#v", r) + logger.Infof("%#v", r) var payload string switch r.URL.Path { case "/oauth/token": diff --git a/pkg/apis/middleware/middleware_suite_test.go b/pkg/apis/middleware/middleware_suite_test.go index ff5f8f92..bf79ec86 100644 --- a/pkg/apis/middleware/middleware_suite_test.go +++ b/pkg/apis/middleware/middleware_suite_test.go @@ -1,6 +1,7 @@ package middleware_test import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -12,8 +13,7 @@ import ( // to prevent circular imports with the `logger` package which uses // this functionality func TestMiddlewareSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Middleware API") diff --git a/pkg/apis/options/cookie.go b/pkg/apis/options/cookie.go index 3dee9505..c6d8d86f 100644 --- a/pkg/apis/options/cookie.go +++ b/pkg/apis/options/cookie.go @@ -75,7 +75,7 @@ func (c *Cookie) GetSecret() (secret string, err error) { fileSecret, err := os.ReadFile(c.SecretFile) if err != nil { - logger.Errorf("error reading cookie secret file %s: %s", c.SecretFile, err) + logger.ErrMsgf("error reading cookie secret file %s: %s", c.SecretFile, err) return "", errors.New("could not read cookie secret file") } diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index e53fd480..2d3e6872 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -164,7 +164,7 @@ func (l *LegacyUpstreams) convert() (UpstreamConfig, error) { case "static": responseCode, err := strconv.Atoi(u.Host) if err != nil { - logger.Errorf("unable to convert %q to int, use default \"200\"", u.Host) + logger.ErrMsgf("unable to convert %q to int, use default \"200\"", u.Host) responseCode = 200 } upstream.Static = ptr.To(true) @@ -782,8 +782,8 @@ func (l *LegacyProvider) convert() (Providers, error) { case "google": if len(l.GoogleGroupsLegacy) != 0 && !reflect.DeepEqual(l.GoogleGroupsLegacy, l.GoogleGroups) { // Log the deprecation notice - logger.Error( - "WARNING: The 'OAUTH2_PROXY_GOOGLE_GROUP' environment variable is deprecated and will likely be removed in the next major release. Use 'OAUTH2_PROXY_GOOGLE_GROUPS' instead.", + logger.Warn( + "the 'OAUTH2_PROXY_GOOGLE_GROUP' environment variable is deprecated and will likely be removed in the next major release. Use 'OAUTH2_PROXY_GOOGLE_GROUPS' instead.", ) l.GoogleGroups = l.GoogleGroupsLegacy } diff --git a/pkg/apis/options/logging.go b/pkg/apis/options/logging.go index dfffd0fa..3d3f642d 100644 --- a/pkg/apis/options/logging.go +++ b/pkg/apis/options/logging.go @@ -7,6 +7,8 @@ import ( // Logging contains all options required for configuring the logging type Logging struct { + Level string `flag:"logging-level" cfg:"logging_level"` + Format string `flag:"logging-format" cfg:"logging_format"` AuthEnabled bool `flag:"auth-logging" cfg:"auth_logging"` AuthFormat string `flag:"auth-logging-format" cfg:"auth_logging_format"` RequestEnabled bool `flag:"request-logging" cfg:"request_logging"` @@ -33,13 +35,15 @@ type LogFileOptions struct { func loggingFlagSet() *pflag.FlagSet { flagSet := pflag.NewFlagSet("logging", pflag.ExitOnError) + flagSet.String("logging-level", "info", "Log level: debug, info, warn, error") + flagSet.String("logging-format", "text", "Log format: text, json") flagSet.Bool("auth-logging", true, "Log authentication attempts") - flagSet.String("auth-logging-format", logger.DefaultAuthLoggingFormat, "Template for authentication log lines") + flagSet.String("auth-logging-format", logger.DefaultAuthLoggingFormat, "Template for authentication log lines when logging-format=text") flagSet.Bool("standard-logging", true, "Log standard runtime information") - flagSet.String("standard-logging-format", logger.DefaultStandardLoggingFormat, "Template for standard log lines") + flagSet.String("standard-logging-format", logger.DefaultStandardLoggingFormat, "Template for standard log lines when logging-format=text") flagSet.Bool("request-logging", true, "Log HTTP requests") - flagSet.String("request-logging-format", logger.DefaultRequestLoggingFormat, "Template for HTTP request log lines") - flagSet.Bool("errors-to-info-log", false, "Log errors to the standard logging channel instead of stderr") + flagSet.String("request-logging-format", logger.DefaultRequestLoggingFormat, "Template for HTTP request log lines when logging-format=text") + flagSet.Bool("errors-to-info-log", false, "Log errors to the standard logging channel instead of stderr") flagSet.StringSlice("exclude-logging-path", []string{}, "Exclude logging requests to paths (eg: '/path1,/path2,/path3')") flagSet.Bool("logging-local-time", true, "If the time in log files and backup filenames are local or UTC time") @@ -58,6 +62,8 @@ func loggingFlagSet() *pflag.FlagSet { // loggingDefaults creates a Logging structure, populating each field with its default value func loggingDefaults() Logging { return Logging{ + Level: "info", + Format: "text", ExcludePaths: nil, LocalTime: true, SilencePing: false, diff --git a/pkg/apis/options/options_suite_test.go b/pkg/apis/options/options_suite_test.go index 247344e7..221d9671 100644 --- a/pkg/apis/options/options_suite_test.go +++ b/pkg/apis/options/options_suite_test.go @@ -1,6 +1,7 @@ package options import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestOptionsSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Options Suite") diff --git a/pkg/apis/options/util/util_suite_test.go b/pkg/apis/options/util/util_suite_test.go index 56cc8aad..7bddc0d7 100644 --- a/pkg/apis/options/util/util_suite_test.go +++ b/pkg/apis/options/util/util_suite_test.go @@ -1,6 +1,7 @@ package util import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestUtilSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Options Util Suite") diff --git a/pkg/app/pagewriter/error_page.go b/pkg/app/pagewriter/error_page.go index e62a9a52..b8f0f414 100644 --- a/pkg/app/pagewriter/error_page.go +++ b/pkg/app/pagewriter/error_page.go @@ -79,7 +79,7 @@ func (e *errorPageWriter) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageO } if err := e.template.Execute(rw, data); err != nil { - logger.Printf("Error rendering error template: %v", err) + logger.ErrMsgf("error rendering error template: %v", err) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } } @@ -88,7 +88,7 @@ func (e *errorPageWriter) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageO // when there are issues with upstream servers. // It is expected to always render a bad gateway error. func (e *errorPageWriter) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { - logger.Errorf("Error proxying to upstream server: %v", proxyErr) + logger.ErrMsgf("error proxying to upstream server: %v", proxyErr) scope := middlewareapi.GetRequestScope(req) e.WriteErrorPage(rw, ErrorPageOpts{ Status: http.StatusBadGateway, diff --git a/pkg/app/pagewriter/pagewriter_suite_test.go b/pkg/app/pagewriter/pagewriter_suite_test.go index bf69c3df..aac79dd5 100644 --- a/pkg/app/pagewriter/pagewriter_suite_test.go +++ b/pkg/app/pagewriter/pagewriter_suite_test.go @@ -1,6 +1,7 @@ package pagewriter import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -11,8 +12,7 @@ import ( const testRequestID = "11111111-2222-4333-8444-555555555555" func TestOptionsSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "App Suite") diff --git a/pkg/app/pagewriter/sign_in_page.go b/pkg/app/pagewriter/sign_in_page.go index bb34bb49..0fd19974 100644 --- a/pkg/app/pagewriter/sign_in_page.go +++ b/pkg/app/pagewriter/sign_in_page.go @@ -79,7 +79,7 @@ func (s *signInPageWriter) WriteSignInPage(rw http.ResponseWriter, req *http.Req err := s.template.Execute(rw, t) if err != nil { - logger.Printf("Error rendering sign-in template: %v", err) + logger.ErrMsgf("error rendering sign-in template: %v", err) scope := middlewareapi.GetRequestScope(req) s.errorPageWriter.WriteErrorPage(rw, ErrorPageOpts{ Status: http.StatusInternalServerError, diff --git a/pkg/app/pagewriter/static_pages.go b/pkg/app/pagewriter/static_pages.go index 3e1bdcdc..873a77b6 100644 --- a/pkg/app/pagewriter/static_pages.go +++ b/pkg/app/pagewriter/static_pages.go @@ -34,7 +34,7 @@ func (s *staticPageWriter) WriteRobotsTxt(rw http.ResponseWriter, req *http.Requ func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request, pageName string) { _, err := rw.Write(s.pageGetter.getPage(pageName)) if err != nil { - logger.Printf("Error writing %q: %v", pageName, err) + logger.ErrMsgf("error writing %q: %v", pageName, err) scope := middlewareapi.GetRequestScope(req) s.errorPageWriter.WriteErrorPage(rw, ErrorPageOpts{ Status: http.StatusInternalServerError, diff --git a/pkg/app/pagewriter/templates.go b/pkg/app/pagewriter/templates.go index 66d24ecc..e0c22c5d 100644 --- a/pkg/app/pagewriter/templates.go +++ b/pkg/app/pagewriter/templates.go @@ -60,7 +60,7 @@ func addTemplate(t *template.Template, customDir, fileName, defaultTemplate stri if err != nil { // This should not happen. // Default templates should be tested and so should never fail to parse. - logger.Panic("Could not parse defaultTemplate: ", err) + logger.PanicMsg("could not parse defaultTemplate", "error", err) } return t, nil } @@ -70,7 +70,7 @@ func addTemplate(t *template.Template, customDir, fileName, defaultTemplate stri func isFile(fileName string) bool { info, err := os.Stat(fileName) if err != nil { - logger.Errorf("Could not load file %s: %v, will use default template", fileName, err) + logger.ErrMsgf("could not load file %s: %v, will use default template", fileName, err) return false } return info.Mode().IsRegular() diff --git a/pkg/app/redirect/director.go b/pkg/app/redirect/director.go index 0ed8e408..d33843ee 100644 --- a/pkg/app/redirect/director.go +++ b/pkg/app/redirect/director.go @@ -82,7 +82,7 @@ func (a *appDirector) validateRedirect(redirect string, errorFormat string) stri return redirect } if redirect != "" { - logger.Errorf(errorFormat, redirect) + logger.ErrMsgf(errorFormat, redirect) } return "" } diff --git a/pkg/app/redirect/pagewriter_suite_test.go b/pkg/app/redirect/pagewriter_suite_test.go index d02ccc9e..dbe53d16 100644 --- a/pkg/app/redirect/pagewriter_suite_test.go +++ b/pkg/app/redirect/pagewriter_suite_test.go @@ -1,6 +1,7 @@ package redirect import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestOptionsSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Redirect Suite") diff --git a/pkg/app/redirect/validator.go b/pkg/app/redirect/validator.go index fd9074e2..eb5c9437 100644 --- a/pkg/app/redirect/validator.go +++ b/pkg/app/redirect/validator.go @@ -49,7 +49,7 @@ func (v *validator) IsValidRedirect(redirect string) bool { case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): redirectURL, err := url.Parse(redirect) if err != nil { - logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) + logger.Infof("rejecting invalid redirect %q: scheme unsupported or missing", redirect) return false } @@ -57,10 +57,10 @@ func (v *validator) IsValidRedirect(redirect string) bool { return true } - logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) + logger.Infof("rejecting invalid redirect %q: domain / port not in whitelist", redirect) return false default: - logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) + logger.Infof("rejecting invalid redirect %q: not an absolute or relative URL", redirect) return false } } diff --git a/pkg/authentication/basic/basic_suite_test.go b/pkg/authentication/basic/basic_suite_test.go index ca0acb07..c2661b73 100644 --- a/pkg/authentication/basic/basic_suite_test.go +++ b/pkg/authentication/basic/basic_suite_test.go @@ -1,6 +1,7 @@ package basic import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestBasicSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Basic") diff --git a/pkg/authentication/basic/htpasswd.go b/pkg/authentication/basic/htpasswd.go index edc2e6b0..d9774775 100644 --- a/pkg/authentication/basic/htpasswd.go +++ b/pkg/authentication/basic/htpasswd.go @@ -42,7 +42,7 @@ func NewHTPasswdValidator(path string) (Validator, error) { if err := watcher.WatchFileForUpdates(path, nil, func() { err := h.loadHTPasswdFile(path) if err != nil { - logger.Errorf("%v: no changes were made to the current htpasswd map", err) + logger.ErrMsgf("%v: no changes were made to the current htpasswd map", err) } }); err != nil { return nil, fmt.Errorf("could not watch htpasswd file: %v", err) @@ -61,7 +61,7 @@ func (h *htpasswdMap) loadHTPasswdFile(filename string) error { defer func(c io.Closer) { cerr := c.Close() if cerr != nil { - logger.Fatalf("error closing the htpasswd file: %v", cerr) + logger.FatalMsgf("error closing the htpasswd file: %v", cerr) } }(r) diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index be5ee451..d8e4ffdd 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -28,9 +28,9 @@ func MakeCookieFromOptions(req *http.Request, opts *CookieOptions) *http.Cookie domain := GetCookieDomain(req, opts.Domains) // If nothing matches, create the cookie with the shortest domain if domain == "" && len(opts.Domains) > 0 { - logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", - requestutil.GetRequestHost(req), - strings.Join(opts.Domains, ","), + logger.Warn("request host did not match any specific cookie domains", + "host", requestutil.GetRequestHost(req), + "domains", strings.Join(opts.Domains, ","), ) domain = opts.Domains[len(opts.Domains)-1] } @@ -96,6 +96,6 @@ func warnInvalidDomain(c *http.Cookie, req *http.Request) { host = h } if !strings.HasSuffix(host, c.Domain) { - logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", host, c.Domain) + logger.Warn("request host does not match configured cookie domain", "host", host, "domain", c.Domain) } } diff --git a/pkg/cookies/cookies_suite_test.go b/pkg/cookies/cookies_suite_test.go index a11dd798..740bde6f 100644 --- a/pkg/cookies/cookies_suite_test.go +++ b/pkg/cookies/cookies_suite_test.go @@ -1,6 +1,7 @@ package cookies import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -25,7 +26,7 @@ const ( ) func TestProviderSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Cookies") diff --git a/pkg/header/header_suite_test.go b/pkg/header/header_suite_test.go index a1ee3768..6e4e3253 100644 --- a/pkg/header/header_suite_test.go +++ b/pkg/header/header_suite_test.go @@ -1,6 +1,7 @@ package header import ( + "log/slog" "os" "path" "testing" @@ -15,8 +16,7 @@ var ( ) func TestHeaderSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Header") diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index cb1da0a4..b5f4731c 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -2,8 +2,10 @@ package logger import ( "bytes" + "context" "fmt" "io" + "log/slog" "net/http" "net/url" "os" @@ -19,15 +21,15 @@ import ( // AuthStatus defines the different types of auth logging that occur type AuthStatus string -// Level indicates the log level for log messages +// Level indicates the log level for legacy logger callers. type Level int const ( - // DefaultStandardLoggingFormat defines the default standard log format + // DefaultStandardLoggingFormat defines the default standard log format. DefaultStandardLoggingFormat = "[{{.Timestamp}}] [{{.File}}] {{.Message}}" - // DefaultAuthLoggingFormat defines the default auth log format + // DefaultAuthLoggingFormat defines the default auth log format. DefaultAuthLoggingFormat = "{{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] [{{.Status}}] {{.Message}}" - // DefaultRequestLoggingFormat defines the default request log format + // DefaultRequestLoggingFormat defines the default request log format. DefaultRequestLoggingFormat = "{{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" // AuthSuccess indicates that an auth attempt has succeeded explicitly @@ -36,20 +38,24 @@ const ( AuthFailure AuthStatus = "AuthFailure" // AuthError indicates that an auth attempt has failed due to an error AuthError AuthStatus = "AuthError" +) - // Llongfile flag to log full file name and line number: /a/b/c/d.go:23 - Llongfile = 1 << iota - // Lshortfile flag to log final file name element and line number: d.go:23. overrides Llongfile - Lshortfile - // LUTC flag to log UTC datetime rather than the local time zone - LUTC - // LstdFlags flag for initial values for the logger +const ( + // Llongfile logs the full file name and line number: /a/b/c/d.go:23. + Llongfile = 1 << 6 + // Lshortfile logs the final file name element and line number: d.go:23. It overrides Llongfile. + Lshortfile = 1 << 7 + // LUTC logs UTC timestamps rather than local time. + LUTC = 1 << 8 + // LstdFlags is the initial value for the logger flags. LstdFlags = Lshortfile +) - // DEFAULT is the default log level (effectively INFO) - DEFAULT Level = iota - // ERROR is for error-level logging - ERROR +const ( + // DEFAULT is the default legacy log level, effectively info. + DEFAULT Level = 10 + // ERROR is the legacy error log level. + ERROR Level = 11 ) // These are the containers for all values that are available as variables in the logging formats. @@ -89,147 +95,566 @@ type reqLogMessageData struct { Username string } -// Returns the apparent "real client IP" as a string. +// GetClientFunc returns the apparent "real client IP" as a string. type GetClientFunc = func(r *http.Request) string -// A Logger represents an active logging object that generates lines of -// output to an io.Writer passed through a formatter. Each logging -// operation makes a single call to the Writer's Write method. A Logger -// can be used simultaneously from multiple goroutines; it guarantees to -// serialize access to the Writer. -type Logger struct { - mu sync.Mutex - flag int - writer io.Writer - errWriter io.Writer - stdEnabled bool - authEnabled bool - reqEnabled bool - getClientFunc GetClientFunc - excludePaths map[string]struct{} - stdLogTemplate *template.Template - authTemplate *template.Template - reqTemplate *template.Template +// exitFunc is the function called by Fatal. Override in tests. +var exitFunc = os.Exit + +// Package-level state +var ( + mu sync.RWMutex + logLevel = new(slog.LevelVar) + defaultLogger = slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: logLevel, AddSource: true})) + writer io.Writer = os.Stdout + errWriter io.Writer = os.Stderr + logFormat = "text" + errToInfo = false + standardEnabled = true + localTime = true + flags = LstdFlags + stdLogTemplate = template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat)) + authTemplate = template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat)) + reqTemplate = template.Must(template.New("req-log").Parse(DefaultRequestLoggingFormat)) + + getClientFunc GetClientFunc = func(r *http.Request) string { return r.RemoteAddr } + excludePaths map[string]struct{} + authEnabled = true + reqEnabled = true +) + +func init() { + logLevel.Set(slog.LevelInfo) + defaultLogger = slog.New(newLevelSplitHandler(logFormat, writer, errWriter)) + slog.SetDefault(defaultLogger) } -// New creates a new Standarderr Logger. -func New(flag int) *Logger { - return &Logger{ - writer: os.Stdout, - errWriter: os.Stderr, - flag: flag, - stdEnabled: true, - authEnabled: true, - reqEnabled: true, - getClientFunc: func(r *http.Request) string { return r.RemoteAddr }, - excludePaths: nil, - stdLogTemplate: template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat)), - authTemplate: template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat)), - reqTemplate: template.Must(template.New("req-log").Parse(DefaultRequestLoggingFormat)), +// Setup initializes the logger with the given level, format, and writers. +// format must be "json" or "text". +func Setup(level slog.Level, format string, w io.Writer, errW io.Writer) { + mu.Lock() + defer mu.Unlock() + + logLevel.Set(level) + writer = w + errWriter = errW + logFormat = format + + handler := newLevelSplitHandler(format, w, errW) + defaultLogger = slog.New(handler) + slog.SetDefault(defaultLogger) +} + +// SetLevel changes the log level at runtime. +func SetLevel(level slog.Level) { + mu.Lock() + defer mu.Unlock() + logLevel.Set(level) +} + +// GetLevel returns the current log level. +func GetLevel() slog.Level { + mu.RLock() + defer mu.RUnlock() + return logLevel.Level() +} + +// SetOutput changes the standard output writer and reconfigures the handler. +func SetOutput(w io.Writer) { + mu.Lock() + defer mu.Unlock() + writer = w + if errToInfo { + errWriter = w + } + handler := newLevelSplitHandler(logFormat, writer, errWriter) + defaultLogger = slog.New(handler) + slog.SetDefault(defaultLogger) +} + +// SetErrOutput changes the error output writer and reconfigures the handler. +func SetErrOutput(w io.Writer) { + mu.Lock() + defer mu.Unlock() + errWriter = w + handler := newLevelSplitHandler(logFormat, writer, errWriter) + defaultLogger = slog.New(handler) + slog.SetDefault(defaultLogger) +} + +// SetErrToInfo routes error-level logs to the standard writer instead of the error writer. +func SetErrToInfo(e bool) { + mu.Lock() + defer mu.Unlock() + errToInfo = e + ew := errWriter + if e { + ew = writer + } + handler := newLevelSplitHandler(logFormat, writer, ew) + defaultLogger = slog.New(handler) + slog.SetDefault(defaultLogger) +} + +// SetGetClientFunc sets the function which determines the apparent "real client IP". +func SetGetClientFunc(f GetClientFunc) { + mu.Lock() + defer mu.Unlock() + getClientFunc = f +} + +// SetExcludePaths sets the paths to exclude from request logging. +func SetExcludePaths(s []string) { + mu.Lock() + defer mu.Unlock() + excludePaths = make(map[string]struct{}) + for _, p := range s { + excludePaths[p] = struct{}{} } } -var std = New(LstdFlags) +// SetAuthEnabled enables or disables auth logging. +func SetAuthEnabled(e bool) { + mu.Lock() + defer mu.Unlock() + authEnabled = e +} -func (l *Logger) formatLogMessage(calldepth int, message string) []byte { - now := time.Now() - file := "???:0" +// SetReqEnabled enables or disables request logging. +func SetReqEnabled(e bool) { + mu.Lock() + defer mu.Unlock() + reqEnabled = e +} - if l.flag&(Lshortfile|Llongfile) != 0 { - file = l.GetFileLineString(calldepth + 1) +// SetStandardEnabled enables or disables standard runtime logging. +func SetStandardEnabled(e bool) { + mu.Lock() + defer mu.Unlock() + standardEnabled = e +} + +// SetStandardTemplate sets the template for standard text logging. +func SetStandardTemplate(t string) { + mu.Lock() + defer mu.Unlock() + stdLogTemplate = template.Must(template.New("std-log").Parse(t)) +} + +// SetAuthTemplate sets the template for authentication text logging. +func SetAuthTemplate(t string) { + mu.Lock() + defer mu.Unlock() + authTemplate = template.Must(template.New("auth-log").Parse(t)) +} + +// SetReqTemplate sets the template for request text logging. +func SetReqTemplate(t string) { + mu.Lock() + defer mu.Unlock() + reqTemplate = template.Must(template.New("req-log").Parse(t)) +} + +// SetLocalTime controls whether template text logs use local time or UTC. +func SetLocalTime(e bool) { + mu.Lock() + defer mu.Unlock() + localTime = e + if e { + flags &^= LUTC + } else { + flags |= LUTC + } +} + +// FormatTimestamp returns a formatted timestamp using the current text log time zone setting. +func FormatTimestamp(ts time.Time) string { + mu.RLock() + useLocalTime := localTime + mu.RUnlock() + + if !useLocalTime { + ts = ts.UTC() } - var logBuff = new(bytes.Buffer) - err := l.stdLogTemplate.Execute(logBuff, stdLogMessageData{ - Timestamp: FormatTimestamp(now), - File: file, - Message: message, + return ts.Format("2006/01/02 15:04:05") +} + +// Flags returns the output flags for the standard logger. +func Flags() int { + mu.RLock() + defer mu.RUnlock() + return flags +} + +// SetFlags sets the output flags for the standard logger. +func SetFlags(flag int) { + mu.Lock() + defer mu.Unlock() + flags = flag + localTime = flag&LUTC == 0 +} + +// ---------- Structured log functions ---------- + +// Debug logs a message at Debug level with optional structured key-value pairs. +func Debug(msg string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelDebug, 3, msg, args...) + return + } + + defaultLogger.Debug(msg, args...) +} + +// Debugf logs a formatted message at Debug level. +func Debugf(format string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelDebug, 3, fmt.Sprintf(format, args...)) + return + } + + defaultLogger.Debug(fmt.Sprintf(format, args...)) +} + +// Info logs a message at Info level with optional structured key-value pairs. +func Info(msg string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelInfo, 3, msg, args...) + return + } + + defaultLogger.Info(msg, args...) +} + +// Infof logs a formatted message at Info level. +func Infof(format string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelInfo, 3, fmt.Sprintf(format, args...)) + return + } + + defaultLogger.Info(fmt.Sprintf(format, args...)) +} + +// Warn logs a message at Warn level with optional structured key-value pairs. +func Warn(msg string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelWarn, 3, msg, args...) + return + } + + defaultLogger.Warn(msg, args...) +} + +// Warnf logs a formatted message at Warn level. +func Warnf(format string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelWarn, 3, fmt.Sprintf(format, args...)) + return + } + + defaultLogger.Warn(fmt.Sprintf(format, args...)) +} + +// ErrMsg logs a message at Error level with optional structured key-value pairs. +func ErrMsg(msg string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelError, 3, msg, args...) + return + } + + defaultLogger.Error(msg, args...) +} + +// ErrMsgf logs a formatted message at Error level. +func ErrMsgf(format string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelError, 3, fmt.Sprintf(format, args...)) + return + } + + defaultLogger.Error(fmt.Sprintf(format, args...)) +} + +// FatalMsg logs a message at Error level and then calls os.Exit(1). +func FatalMsg(msg string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelError, 3, msg, args...) + exitFunc(1) + return + } + + defaultLogger.Error(msg, args...) + exitFunc(1) +} + +// FatalMsgf logs a formatted message at Error level and then calls os.Exit(1). +func FatalMsgf(format string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelError, 3, fmt.Sprintf(format, args...)) + exitFunc(1) + return + } + + defaultLogger.Error(fmt.Sprintf(format, args...)) + exitFunc(1) +} + +// PanicMsg logs a message at Error level and then panics. +func PanicMsg(msg string, args ...any) { + if isTextFormat() { + logStandardText(slog.LevelError, 3, msg, args...) + panic(msg) + } + + defaultLogger.Error(msg, args...) + panic(msg) +} + +// PanicMsgf logs a formatted message at Error level and then panics. +func PanicMsgf(format string, args ...any) { + s := fmt.Sprintf(format, args...) + if isTextFormat() { + logStandardText(slog.LevelError, 3, s) + panic(s) + } + + defaultLogger.Error(s) + panic(s) +} + +func isTextFormat() bool { + mu.RLock() + defer mu.RUnlock() + return logFormat == "text" +} + +func logStandardText(level slog.Level, callerDepth int, msg string, args ...any) { + if level < logLevel.Level() { + return + } + + mu.Lock() + defer mu.Unlock() + + if !standardEnabled { + return + } + + target := writer + if level >= slog.LevelWarn && !errToInfo { + target = errWriter + } + + var logBuff bytes.Buffer + err := stdLogTemplate.Execute(&logBuff, stdLogMessageData{ + Timestamp: formatTimestamp(time.Now(), localTime), + File: sourceFileFromCaller(callerDepth, flags), + Message: messageWithAttrs(msg, args...), }) if err != nil { panic(err) } - _, err = logBuff.Write([]byte("\n")) - if err != nil { + if _, err = logBuff.Write([]byte("\n")); err != nil { panic(err) } - return logBuff.Bytes() -} - -// Output a standard log template with a simple message to default output channel. -// Write a final newline at the end of every message. -func (l *Logger) Output(lvl Level, calldepth int, message string) { - l.mu.Lock() - defer l.mu.Unlock() - if !l.stdEnabled { - return - } - msg := l.formatLogMessage(calldepth+1, message) - - var err error - switch lvl { - case ERROR: - _, err = l.errWriter.Write(msg) - default: - _, err = l.writer.Write(msg) - } - if err != nil { + if _, err = target.Write(logBuff.Bytes()); err != nil { panic(err) } } -// PrintAuthf writes auth info to the logger. Requires an http.Request to -// log request details. Remaining arguments are handled in the manner of -// fmt.Sprintf. Writes a final newline to the end of every message. -func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) { - if !l.authEnabled { - return +func formatTimestamp(ts time.Time, useLocalTime bool) string { + if !useLocalTime { + ts = ts.UTC() } - now := time.Now() + return ts.Format("2006/01/02 15:04:05") +} + +func messageWithAttrs(msg string, args ...any) string { + if len(args) == 0 { + return msg + } + + var buf bytes.Buffer + buf.WriteString(msg) + for i := 0; i < len(args); i += 2 { + buf.WriteByte(' ') + if i+1 >= len(args) { + fmt.Fprint(&buf, args[i]) + continue + } + fmt.Fprint(&buf, args[i]) + buf.WriteByte('=') + fmt.Fprint(&buf, args[i+1]) + } + return buf.String() +} + +func logAuthText(username string, req *http.Request, status AuthStatus, msg string, args ...any) { + if username == "" { + username = "-" + } + + mu.Lock() + defer mu.Unlock() + + scope := middlewareapi.GetRequestScope(req) + err := authTemplate.Execute(writer, authLogMessageData{ + Client: getClientFunc(req), + Host: requestutil.GetRequestHost(req), + Protocol: req.Proto, + RequestID: scope.RequestID, + RequestMethod: req.Method, + Timestamp: formatTimestamp(time.Now(), localTime), + UserAgent: fmt.Sprintf("%q", req.UserAgent()), + Username: username, + Status: string(status), + Message: messageWithAttrs(msg, args...), + }) + if err != nil { + panic(err) + } + + if _, err = writer.Write([]byte("\n")); err != nil { + panic(err) + } +} + +func logRequestText(username, upstream string, req *http.Request, reqURL url.URL, ts time.Time, status int, size int) { + if username == "" { + username = "-" + } + + if upstream == "" { + upstream = "-" + } + + if reqURL.User != nil && username == "-" { + if name := reqURL.User.Username(); name != "" { + username = name + } + } + + duration := float64(time.Since(ts)) / float64(time.Second) + + mu.Lock() + defer mu.Unlock() + + scope := middlewareapi.GetRequestScope(req) + err := reqTemplate.Execute(writer, reqLogMessageData{ + Client: getClientFunc(req), + Host: requestutil.GetRequestHost(req), + Protocol: req.Proto, + RequestID: scope.RequestID, + RequestDuration: fmt.Sprintf("%0.3f", duration), + RequestMethod: req.Method, + RequestURI: fmt.Sprintf("%q", reqURL.RequestURI()), + ResponseSize: fmt.Sprintf("%d", size), + StatusCode: fmt.Sprintf("%d", status), + Timestamp: formatTimestamp(ts, localTime), + Upstream: upstream, + UserAgent: fmt.Sprintf("%q", req.UserAgent()), + Username: username, + }) + if err != nil { + panic(err) + } + + if _, err = writer.Write([]byte("\n")); err != nil { + panic(err) + } +} + +// ---------- Structured auth and request logging ---------- + +// LogAuth logs an authentication event with structured attributes. +// The log level is derived from the AuthStatus: +// - AuthSuccess → Info +// - AuthFailure → Warn +// - AuthError → Error +func LogAuth(username string, req *http.Request, status AuthStatus, msg string, args ...any) { + mu.RLock() + enabled := authEnabled + clientFunc := getClientFunc + format := logFormat + mu.RUnlock() + + if !enabled { + return + } if username == "" { username = "-" } - client := l.getClientFunc(req) - - l.mu.Lock() - defer l.mu.Unlock() - + client := clientFunc(req) scope := middlewareapi.GetRequestScope(req) - err := l.authTemplate.Execute(l.writer, authLogMessageData{ - Client: client, - Host: requestutil.GetRequestHost(req), - Protocol: req.Proto, - RequestID: scope.RequestID, - RequestMethod: req.Method, - Timestamp: FormatTimestamp(now), - UserAgent: fmt.Sprintf("%q", req.UserAgent()), - Username: username, - Status: string(status), - Message: fmt.Sprintf(format, a...), - }) - if err != nil { - panic(err) + + attrs := make([]any, 0, 16+len(args)) + attrs = append(attrs, + "user", username, + "client", client, + "host", requestutil.GetRequestHost(req), + "method", req.Method, + "protocol", req.Proto, + "user_agent", req.UserAgent(), + "request_id", scope.RequestID, + "status", string(status), + ) + attrs = append(attrs, args...) + + var level slog.Level + switch status { + case AuthSuccess: + level = slog.LevelInfo + case AuthFailure: + level = slog.LevelWarn + case AuthError: + level = slog.LevelError + default: + level = slog.LevelInfo } - _, err = l.writer.Write([]byte("\n")) - if err != nil { - panic(err) - } -} - -// PrintReq writes request details to the Logger using the http.Request, -// url, and timestamp of the request. Writes a final newline to the end -// of every message. -func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) { - if !l.reqEnabled { + if level < logLevel.Level() { return } - if _, ok := l.excludePaths[url.Path]; ok { + if format == "text" { + logAuthText(username, req, status, msg, args...) + return + } + + defaultLogger.Log(context.Background(), level, msg, attrs...) +} + +// LogRequest logs an HTTP request with structured attributes at Info level. +// It respects excludePaths and the reqEnabled flag. +func LogRequest(username, upstream string, req *http.Request, reqURL url.URL, ts time.Time, status int, size int) { + mu.RLock() + enabled := reqEnabled + excluded := excludePaths + clientFunc := getClientFunc + format := logFormat + mu.RUnlock() + + if !enabled { + return + } + + if slog.LevelInfo < logLevel.Level() { + return + } + + if _, ok := excluded[reqURL.Path]; ok { + return + } + + if format == "text" { + logRequestText(username, upstream, req, reqURL, ts, status, size) return } @@ -243,330 +668,183 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. upstream = "-" } - if url.User != nil && username == "-" { - if name := url.User.Username(); name != "" { + if reqURL.User != nil && username == "-" { + if name := reqURL.User.Username(); name != "" { username = name } } - client := l.getClientFunc(req) - - l.mu.Lock() - defer l.mu.Unlock() - + client := clientFunc(req) scope := middlewareapi.GetRequestScope(req) - err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ - Client: client, - Host: requestutil.GetRequestHost(req), - Protocol: req.Proto, - RequestID: scope.RequestID, - RequestDuration: fmt.Sprintf("%0.3f", duration), - RequestMethod: req.Method, - RequestURI: fmt.Sprintf("%q", url.RequestURI()), - ResponseSize: fmt.Sprintf("%d", size), - StatusCode: fmt.Sprintf("%d", status), - Timestamp: FormatTimestamp(ts), - Upstream: upstream, - UserAgent: fmt.Sprintf("%q", req.UserAgent()), - Username: username, - }) - if err != nil { - panic(err) + + defaultLogger.Info("request", + "user", username, + "client", client, + "host", requestutil.GetRequestHost(req), + "method", req.Method, + "uri", reqURL.RequestURI(), + "protocol", req.Proto, + "upstream", upstream, + "user_agent", req.UserAgent(), + "status_code", status, + "response_size", size, + "duration_s", fmt.Sprintf("%0.3f", duration), + "request_id", scope.RequestID, + ) +} + +// ---------- Level-split handler ---------- + +// levelSplitHandler routes log records to different writers based on level. +// Records at Warn level and above go to the error handler. +// Records below Warn go to the standard handler. +type levelSplitHandler struct { + stdHandler slog.Handler + errHandler slog.Handler +} + +func newLevelSplitHandler(format string, w io.Writer, errW io.Writer) *levelSplitHandler { + opts := &slog.HandlerOptions{ + Level: logLevel, + AddSource: true, + } + errOpts := &slog.HandlerOptions{ + Level: logLevel, + AddSource: true, } - _, err = l.writer.Write([]byte("\n")) - if err != nil { - panic(err) + var stdH, errH slog.Handler + switch format { + case "text": + stdH = newTemplateTextHandler(w) + errH = newTemplateTextHandler(errW) + default: // "json" + stdH = slog.NewJSONHandler(w, opts) + errH = slog.NewJSONHandler(errW, errOpts) + } + + return &levelSplitHandler{ + stdHandler: stdH, + errHandler: errH, } } -// GetFileLineString will find the caller file and line number -// taking in to account the calldepth to iterate up the stack -// to find the non-logging call location. -func (l *Logger) GetFileLineString(calldepth int) string { - var file string - var line int - var ok bool +func (h *levelSplitHandler) Enabled(_ context.Context, level slog.Level) bool { + return level >= logLevel.Level() +} - _, file, line, ok = runtime.Caller(calldepth) - if !ok { - file = "???" - line = 0 +func (h *levelSplitHandler) Handle(ctx context.Context, r slog.Record) error { + if r.Level >= slog.LevelWarn { + return h.errHandler.Handle(ctx, r) + } + return h.stdHandler.Handle(ctx, r) +} + +func (h *levelSplitHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &levelSplitHandler{ + stdHandler: h.stdHandler.WithAttrs(attrs), + errHandler: h.errHandler.WithAttrs(attrs), + } +} + +func (h *levelSplitHandler) WithGroup(name string) slog.Handler { + return &levelSplitHandler{ + stdHandler: h.stdHandler.WithGroup(name), + errHandler: h.errHandler.WithGroup(name), + } +} + +type templateTextHandler struct { + writer io.Writer + attrs []slog.Attr +} + +func newTemplateTextHandler(w io.Writer) *templateTextHandler { + return &templateTextHandler{writer: w} +} + +func (h *templateTextHandler) Enabled(_ context.Context, level slog.Level) bool { + return level >= logLevel.Level() +} + +func (h *templateTextHandler) Handle(_ context.Context, r slog.Record) error { + mu.Lock() + defer mu.Unlock() + + if !standardEnabled { + return nil } - if l.flag&Lshortfile != 0 { - short := file - for i := len(file) - 1; i > 0; i-- { - if file[i] == '/' { - short = file[i+1:] - break - } - } - file = short + var args []any + for _, attr := range h.attrs { + args = append(args, attr.Key, attr.Value.Any()) + } + r.Attrs(func(attr slog.Attr) bool { + args = append(args, attr.Key, attr.Value.Any()) + return true + }) + + var logBuff bytes.Buffer + err := stdLogTemplate.Execute(&logBuff, stdLogMessageData{ + Timestamp: formatTimestamp(r.Time, localTime), + File: sourceFile(r.PC, flags), + Message: messageWithAttrs(r.Message, args...), + }) + if err != nil { + return err + } + + if _, err = logBuff.Write([]byte("\n")); err != nil { + return err + } + + _, err = h.writer.Write(logBuff.Bytes()) + return err +} + +func (h *templateTextHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + next := &templateTextHandler{writer: h.writer} + next.attrs = append(next.attrs, h.attrs...) + next.attrs = append(next.attrs, attrs...) + return next +} + +func (h *templateTextHandler) WithGroup(_ string) slog.Handler { + return h +} + +func sourceFile(pc uintptr, flag int) string { + if pc == 0 { + return "???:0" + } + + frame, _ := runtime.CallersFrames([]uintptr{pc}).Next() + return formatFileLine(frame.File, frame.Line, flag) +} + +func sourceFileFromCaller(depth int, flag int) string { + _, file, line, ok := runtime.Caller(depth) + if !ok { + return "???:0" + } + + return formatFileLine(file, line, flag) +} + +func formatFileLine(file string, line int, flag int) string { + if flag&Lshortfile != 0 { + file = shortFile(file) } return fmt.Sprintf("%s:%d", file, line) } -// FormatTimestamp returns a formatted timestamp. -func (l *Logger) FormatTimestamp(ts time.Time) string { - if l.flag&LUTC != 0 { - ts = ts.UTC() +func shortFile(file string) string { + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + return file[i+1:] + } } - return ts.Format("2006/01/02 15:04:05") -} - -// Flags returns the output flags for the logger. -func (l *Logger) Flags() int { - l.mu.Lock() - defer l.mu.Unlock() - return l.flag -} - -// SetFlags sets the output flags for the logger. -func (l *Logger) SetFlags(flag int) { - l.mu.Lock() - defer l.mu.Unlock() - l.flag = flag -} - -// SetStandardEnabled enables or disables standard logging. -func (l *Logger) SetStandardEnabled(e bool) { - l.mu.Lock() - defer l.mu.Unlock() - l.stdEnabled = e -} - -// SetErrToInfo enables or disables error logging to error writer instead of the default. -func (l *Logger) SetErrToInfo(e bool) { - l.mu.Lock() - defer l.mu.Unlock() - if e { - l.errWriter = l.writer - } else { - l.errWriter = os.Stderr - } -} - -// SetAuthEnabled enables or disables auth logging. -func (l *Logger) SetAuthEnabled(e bool) { - l.mu.Lock() - defer l.mu.Unlock() - l.authEnabled = e -} - -// SetReqEnabled enabled or disables request logging. -func (l *Logger) SetReqEnabled(e bool) { - l.mu.Lock() - defer l.mu.Unlock() - l.reqEnabled = e -} - -// SetGetClientFunc sets the function which determines the apparent "real client IP". -func (l *Logger) SetGetClientFunc(f GetClientFunc) { - l.mu.Lock() - defer l.mu.Unlock() - l.getClientFunc = f -} - -// SetExcludePaths sets the paths to exclude from logging. -func (l *Logger) SetExcludePaths(s []string) { - l.mu.Lock() - defer l.mu.Unlock() - l.excludePaths = make(map[string]struct{}) - for _, p := range s { - l.excludePaths[p] = struct{}{} - } -} - -// SetStandardTemplate sets the template for standard logging. -func (l *Logger) SetStandardTemplate(t string) { - l.mu.Lock() - defer l.mu.Unlock() - l.stdLogTemplate = template.Must(template.New("std-log").Parse(t)) -} - -// SetAuthTemplate sets the template for auth logging. -func (l *Logger) SetAuthTemplate(t string) { - l.mu.Lock() - defer l.mu.Unlock() - l.authTemplate = template.Must(template.New("auth-log").Parse(t)) -} - -// SetReqTemplate sets the template for request logging. -func (l *Logger) SetReqTemplate(t string) { - l.mu.Lock() - defer l.mu.Unlock() - l.reqTemplate = template.Must(template.New("req-log").Parse(t)) -} - -// These functions utilize the standard logger. - -// FormatTimestamp returns a formatted timestamp for the standard logger. -func FormatTimestamp(ts time.Time) string { - return std.FormatTimestamp(ts) -} - -// Flags returns the output flags for the standard logger. -func Flags() int { - return std.Flags() -} - -// SetFlags sets the output flags for the standard logger. -func SetFlags(flag int) { - std.SetFlags(flag) -} - -// SetOutput sets the output destination for the standard logger's default channel. -func SetOutput(w io.Writer) { - std.mu.Lock() - defer std.mu.Unlock() - std.writer = w -} - -// SetErrOutput sets the output destination for the standard logger's error channel. -func SetErrOutput(w io.Writer) { - std.mu.Lock() - defer std.mu.Unlock() - std.errWriter = w -} - -// SetStandardEnabled enables or disables standard logging for the -// standard logger. -func SetStandardEnabled(e bool) { - std.SetStandardEnabled(e) -} - -// SetErrToInfo enables or disables error logging to output writer instead of -// error writer. -func SetErrToInfo(e bool) { - std.SetErrToInfo(e) -} - -// SetAuthEnabled enables or disables auth logging for the standard -// logger. -func SetAuthEnabled(e bool) { - std.SetAuthEnabled(e) -} - -// SetReqEnabled enables or disables request logging for the -// standard logger. -func SetReqEnabled(e bool) { - std.SetReqEnabled(e) -} - -// SetGetClientFunc sets the function which determines the apparent IP address -// set by a reverse proxy for the standard logger. -func SetGetClientFunc(f GetClientFunc) { - std.SetGetClientFunc(f) -} - -// SetExcludePaths sets the path to exclude from logging, eg: health checks -func SetExcludePaths(s []string) { - std.SetExcludePaths(s) -} - -// SetStandardTemplate sets the template for standard logging for -// the standard logger. -func SetStandardTemplate(t string) { - std.SetStandardTemplate(t) -} - -// SetAuthTemplate sets the template for auth logging for the -// standard logger. -func SetAuthTemplate(t string) { - std.SetAuthTemplate(t) -} - -// SetReqTemplate sets the template for request logging for the -// standard logger. -func SetReqTemplate(t string) { - std.SetReqTemplate(t) -} - -// Print calls Output to print to the standard logger. -// Arguments are handled in the manner of fmt.Print. -func Print(v ...interface{}) { - std.Output(DEFAULT, 2, fmt.Sprint(v...)) -} - -// Printf calls Output to print to the standard logger. -// Arguments are handled in the manner of fmt.Printf. -func Printf(format string, v ...interface{}) { - std.Output(DEFAULT, 2, fmt.Sprintf(format, v...)) -} - -// Println calls Output to print to the standard logger. -// Arguments are handled in the manner of fmt.Println. -func Println(v ...interface{}) { - std.Output(DEFAULT, 2, fmt.Sprintln(v...)) -} - -// Error calls OutputErr to print to the standard logger's error channel. -// Arguments are handled in the manner of fmt.Print. -func Error(v ...interface{}) { - std.Output(ERROR, 2, fmt.Sprint(v...)) -} - -// Errorf calls OutputErr to print to the standard logger's error channel. -// Arguments are handled in the manner of fmt.Printf. -func Errorf(format string, v ...interface{}) { - std.Output(ERROR, 2, fmt.Sprintf(format, v...)) -} - -// Errorln calls OutputErr to print to the standard logger's error channel. -// Arguments are handled in the manner of fmt.Println. -func Errorln(v ...interface{}) { - std.Output(ERROR, 2, fmt.Sprintln(v...)) -} - -// Fatal is equivalent to Print() followed by a call to os.Exit(1). -func Fatal(v ...interface{}) { - std.Output(ERROR, 2, fmt.Sprint(v...)) - os.Exit(1) -} - -// Fatalf is equivalent to Printf() followed by a call to os.Exit(1). -func Fatalf(format string, v ...interface{}) { - std.Output(ERROR, 2, fmt.Sprintf(format, v...)) - os.Exit(1) -} - -// Fatalln is equivalent to Println() followed by a call to os.Exit(1). -func Fatalln(v ...interface{}) { - std.Output(ERROR, 2, fmt.Sprintln(v...)) - os.Exit(1) -} - -// Panic is equivalent to Print() followed by a call to panic(). -func Panic(v ...interface{}) { - s := fmt.Sprint(v...) - std.Output(ERROR, 2, s) - panic(s) -} - -// Panicf is equivalent to Printf() followed by a call to panic(). -func Panicf(format string, v ...interface{}) { - s := fmt.Sprintf(format, v...) - std.Output(ERROR, 2, s) - panic(s) -} - -// Panicln is equivalent to Println() followed by a call to panic(). -func Panicln(v ...interface{}) { - s := fmt.Sprintln(v...) - std.Output(ERROR, 2, s) - panic(s) -} - -// PrintAuthf writes authentication details to the standard logger. -// Arguments are handled in the manner of fmt.Printf. -func PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) { - std.PrintAuthf(username, req, status, format, a...) -} - -// PrintReq writes request details to the standard logger. -func PrintReq(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) { - std.PrintReq(username, upstream, req, url, ts, status, size) + return file } diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go new file mode 100644 index 00000000..57df0661 --- /dev/null +++ b/pkg/logger/logger_test.go @@ -0,0 +1,594 @@ +package logger + +import ( + "bytes" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "text/template" + "time" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" +) + +// resetLogger resets the logger to defaults for test isolation. +func resetLogger(t *testing.T) { + t.Helper() + logLevel.Set(slog.LevelInfo) + logFormat = "text" + standardEnabled = true + localTime = true + flags = LstdFlags + stdLogTemplate = template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat)) + authTemplate = template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat)) + reqTemplate = template.Must(template.New("req-log").Parse(DefaultRequestLoggingFormat)) + authEnabled = true + reqEnabled = true + excludePaths = nil + getClientFunc = func(r *http.Request) string { return r.RemoteAddr } + errToInfo = false +} + +// parseJSON parses a JSON log line into a map. +func parseJSON(t *testing.T, data []byte) map[string]any { + t.Helper() + // Take only the first line if there are multiple + line := strings.TrimSpace(strings.Split(string(data), "\n")[0]) + if line == "" { + t.Fatal("empty log output") + } + var m map[string]any + if err := json.Unmarshal([]byte(line), &m); err != nil { + t.Fatalf("failed to parse JSON log: %v\nraw: %s", err, line) + } + return m +} + +func TestSetup_JSONFormat(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf, errBuf) + + Info("hello", "key", "val") + + m := parseJSON(t, buf.Bytes()) + if m["msg"] != "hello" { + t.Errorf("expected msg=hello, got %v", m["msg"]) + } + if m["key"] != "val" { + t.Errorf("expected key=val, got %v", m["key"]) + } + if m["level"] != "INFO" { + t.Errorf("expected level=INFO, got %v", m["level"]) + } + if _, ok := m["time"]; !ok { + t.Error("expected time field in JSON output") + } + if _, ok := m["source"]; !ok { + t.Error("expected source field in JSON output") + } +} + +func TestSetup_TextFormat(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "text", buf, errBuf) + + Info("hello", "key", "val") + + out := buf.String() + if strings.Contains(out, "level=INFO") { + t.Errorf("did not expect slog key=value text output, got: %s", out) + } + if !strings.Contains(out, "hello key=val") { + t.Errorf("expected message and attrs in template output, got: %s", out) + } + if !strings.HasSuffix(out, "\n") { + t.Errorf("expected text output to end with newline, got: %s", out) + } +} + +func TestTextFormat_StandardTemplate(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "text", buf, errBuf) + SetStandardTemplate("{{.Message}}") + + Info("hello", "key", "val") + + if got, want := buf.String(), "hello key=val\n"; got != want { + t.Errorf("expected custom standard template output %q, got %q", want, got) + } +} + +func TestTextFormat_StandardTemplateFileUsesCaller(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "text", buf, errBuf) + SetStandardTemplate("{{.File}}|{{.Message}}") + + Info("hello") + + out := buf.String() + if !strings.Contains(out, "logger_test.go:") { + t.Errorf("expected caller file in text output, got: %s", out) + } + if strings.Contains(out, "logger.go:") { + t.Errorf("expected external caller file, got logger wrapper file: %s", out) + } +} + +func TestSetFlags_ControlsStandardTemplateFile(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "text", buf, errBuf) + SetStandardTemplate("{{.File}}") + SetFlags(0) + + Info("hello") + + out := buf.String() + if !strings.Contains(out, "/pkg/logger/logger_test.go:") { + t.Errorf("expected full caller file when Lshortfile is disabled, got: %s", out) + } +} + +func TestTextFormat_StandardLoggingDisabled(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "text", buf, errBuf) + SetStandardEnabled(false) + + Info("hidden") + Warn("also hidden") + + if buf.Len() > 0 || errBuf.Len() > 0 { + t.Errorf("expected standard logs to be disabled, stdout=%q stderr=%q", buf.String(), errBuf.String()) + } +} + +func TestLevelFiltering(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf, errBuf) + + Debug("should not appear") + if buf.Len() > 0 { + t.Error("Debug message should be filtered at Info level") + } + + Info("should appear") + if buf.Len() == 0 { + t.Error("Info message should appear at Info level") + } + m := parseJSON(t, buf.Bytes()) + if m["msg"] != "should appear" { + t.Errorf("expected msg='should appear', got %v", m["msg"]) + } + + // Error should go to errBuf + ErrMsg("error msg") + if errBuf.Len() == 0 { + t.Error("Error message should appear in errBuf") + } +} + +func TestSetLevel_Runtime(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf, errBuf) + + Debug("hidden") + if buf.Len() > 0 { + t.Error("Debug should be hidden at Info level") + } + + SetLevel(slog.LevelDebug) + + Debug("visible") + if buf.Len() == 0 { + t.Error("Debug should be visible after SetLevel(Debug)") + } + m := parseJSON(t, buf.Bytes()) + if m["msg"] != "visible" { + t.Errorf("expected msg='visible', got %v", m["msg"]) + } +} + +func TestLevelSplitHandler(t *testing.T) { + resetLogger(t) + stdBuf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", stdBuf, errBuf) + + Info("info msg") + if stdBuf.Len() == 0 { + t.Error("Info should go to stdout") + } + if errBuf.Len() > 0 { + t.Error("Info should NOT go to stderr") + } + + stdBuf.Reset() + + Warn("warn msg") + if errBuf.Len() == 0 { + t.Error("Warn should go to stderr") + } + if stdBuf.Len() > 0 { + t.Error("Warn should NOT go to stdout") + } + + errBuf.Reset() + + ErrMsg("error msg") + if errBuf.Len() == 0 { + t.Error("Error should go to stderr") + } +} + +func TestErrToInfo(t *testing.T) { + resetLogger(t) + stdBuf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", stdBuf, errBuf) + + SetErrToInfo(true) + + ErrMsg("goes to stdout") + if stdBuf.Len() == 0 { + t.Error("Error should go to stdout when ErrToInfo is true") + } + if errBuf.Len() > 0 { + t.Error("Error should NOT go to stderr when ErrToInfo is true") + } +} + +func TestLogAuth_Success(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + scope := &middlewareapi.RequestScope{RequestID: "test-request-id"} + req = middlewareapi.AddRequestScope(req, scope) + + LogAuth("user@test.com", req, AuthSuccess, "authenticated via OAuth2") + + m := parseJSON(t, buf.Bytes()) + if m["level"] != "INFO" { + t.Errorf("AuthSuccess should log at INFO, got %v", m["level"]) + } + if m["user"] != "user@test.com" { + t.Errorf("expected user=user@test.com, got %v", m["user"]) + } + if m["status"] != "AuthSuccess" { + t.Errorf("expected status=AuthSuccess, got %v", m["status"]) + } + if m["request_id"] != "test-request-id" { + t.Errorf("expected request_id=test-request-id, got %v", m["request_id"]) + } + if m["msg"] != "authenticated via OAuth2" { + t.Errorf("expected msg='authenticated via OAuth2', got %v", m["msg"]) + } +} + +func TestLogAuth_Failure(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + + req := httptest.NewRequest("GET", "/test", nil) + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + LogAuth("bad-user", req, AuthFailure, "invalid credentials") + + // AuthFailure → Warn → goes to errBuf + m := parseJSON(t, errBuf.Bytes()) + if m["level"] != "WARN" { + t.Errorf("AuthFailure should log at WARN, got %v", m["level"]) + } + if m["status"] != "AuthFailure" { + t.Errorf("expected status=AuthFailure, got %v", m["status"]) + } +} + +func TestLogAuth_Error(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + + req := httptest.NewRequest("GET", "/test", nil) + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + LogAuth("user", req, AuthError, "internal error") + + m := parseJSON(t, errBuf.Bytes()) + if m["level"] != "ERROR" { + t.Errorf("AuthError should log at ERROR, got %v", m["level"]) + } +} + +func TestLogAuth_Disabled(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + SetAuthEnabled(false) + + req := httptest.NewRequest("GET", "/test", nil) + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + LogAuth("user", req, AuthSuccess, "should not appear") + + if buf.Len() > 0 || errBuf.Len() > 0 { + t.Error("LogAuth should produce no output when auth logging is disabled") + } +} + +func TestLogAuth_TextTemplate(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "text", buf, errBuf) + SetAuthTemplate("{{.Client}}|{{.RequestID}}|{{.Username}}|{{.Status}}|{{.Message}}") + + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + LogAuth("user@test.com", req, AuthSuccess, "authenticated", "provider", "oidc") + + if got, want := buf.String(), "192.168.1.1:1234|req-id|user@test.com|AuthSuccess|authenticated provider=oidc\n"; got != want { + t.Errorf("expected custom auth template output %q, got %q", want, got) + } + if errBuf.Len() > 0 { + t.Errorf("expected text auth log to use standard writer, got stderr %q", errBuf.String()) + } +} + +func TestLogRequest(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + + req := httptest.NewRequest("GET", "/foo/bar", nil) + req.RemoteAddr = "127.0.0.1:5678" + scope := &middlewareapi.RequestScope{RequestID: "req-123"} + req = middlewareapi.AddRequestScope(req, scope) + + reqURL := *req.URL + LogRequest("testuser", "backend", req, reqURL, time.Now(), 200, 1024) + + // LogRequest uses time.Now() internally so we can't easily test duration_s exactly. + // Just parse and check fields. + m := parseJSON(t, buf.Bytes()) + if m["msg"] != "request" { + t.Errorf("expected msg=request, got %v", m["msg"]) + } + if m["level"] != "INFO" { + t.Errorf("expected level=INFO, got %v", m["level"]) + } + if m["user"] != "testuser" { + t.Errorf("expected user=testuser, got %v", m["user"]) + } + if m["upstream"] != "backend" { + t.Errorf("expected upstream=backend, got %v", m["upstream"]) + } + if m["method"] != "GET" { + t.Errorf("expected method=GET, got %v", m["method"]) + } + // status_code comes as float64 from JSON + if sc, ok := m["status_code"].(float64); !ok || int(sc) != 200 { + t.Errorf("expected status_code=200, got %v", m["status_code"]) + } + if rs, ok := m["response_size"].(float64); !ok || int(rs) != 1024 { + t.Errorf("expected response_size=1024, got %v", m["response_size"]) + } + if _, ok := m["duration_s"]; !ok { + t.Error("expected duration_s field") + } + if m["request_id"] != "req-123" { + t.Errorf("expected request_id=req-123, got %v", m["request_id"]) + } +} + +func TestLogRequest_ExcludePaths(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + SetExcludePaths([]string{"/healthz", "/ping"}) + + req := httptest.NewRequest("GET", "/healthz", nil) + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + reqURL := *req.URL + LogRequest("user", "-", req, reqURL, time.Now(), 200, 0) + + if buf.Len() > 0 { + t.Error("LogRequest should not log excluded paths") + } +} + +func TestLogRequest_Disabled(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + SetReqEnabled(false) + + req := httptest.NewRequest("GET", "/foo", nil) + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + reqURL := *req.URL + LogRequest("user", "-", req, reqURL, time.Now(), 200, 0) + + if buf.Len() > 0 { + t.Error("LogRequest should produce no output when request logging is disabled") + } +} + +func TestLogRequest_TextTemplate(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "text", buf, errBuf) + SetReqTemplate("{{.Username}}|{{.Upstream}}|{{.RequestMethod}}|{{.RequestURI}}|{{.StatusCode}}|{{.ResponseSize}}|{{.RequestID}}") + + req := httptest.NewRequest("GET", "/foo/bar?x=1", nil) + req.RemoteAddr = "127.0.0.1:5678" + scope := &middlewareapi.RequestScope{RequestID: "req-123"} + req = middlewareapi.AddRequestScope(req, scope) + + reqURL := *req.URL + LogRequest("testuser", "backend", req, reqURL, time.Now(), 204, 42) + + if got, want := buf.String(), "testuser|backend|GET|\"/foo/bar?x=1\"|204|42|req-123\n"; got != want { + t.Errorf("expected custom request template output %q, got %q", want, got) + } +} + +func TestFatalMsg(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf, errBuf) + + // Override exitFunc to capture the exit code + var exitCode int + exitFunc = func(code int) { exitCode = code } + defer func() { exitFunc = nil }() // will panic on real exit if not restored + + FatalMsg("fatal error") + + if exitCode != 1 { + t.Errorf("expected exit code 1, got %d", exitCode) + } + if errBuf.Len() == 0 { + t.Error("Fatal should produce error-level output") + } + m := parseJSON(t, errBuf.Bytes()) + if m["level"] != "ERROR" { + t.Errorf("expected level=ERROR, got %v", m["level"]) + } + + // Restore + exitFunc = func(code int) {} +} + +func TestInfof(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf, errBuf) + + Infof("hello %s", "world") + + m := parseJSON(t, buf.Bytes()) + if m["msg"] != "hello world" { + t.Errorf("expected msg='hello world', got %v", m["msg"]) + } + if m["level"] != "INFO" { + t.Errorf("expected level=INFO, got %v", m["level"]) + } +} + +func TestErrMsgf(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf, errBuf) + + ErrMsgf("error: %s", "something") + + m := parseJSON(t, errBuf.Bytes()) + if m["msg"] != "error: something" { + t.Errorf("expected msg='error: something', got %v", m["msg"]) + } + if m["level"] != "ERROR" { + t.Errorf("expected level=ERROR, got %v", m["level"]) + } +} + +func TestLogAuth(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelDebug, "json", buf, errBuf) + + req := httptest.NewRequest("GET", "/test", nil) + scope := &middlewareapi.RequestScope{RequestID: "req-id"} + req = middlewareapi.AddRequestScope(req, scope) + + LogAuth("user@test.com", req, AuthSuccess, "authenticated via OAuth2") + + m := parseJSON(t, buf.Bytes()) + if m["msg"] != "authenticated via OAuth2" { + t.Errorf("expected msg='authenticated via OAuth2', got %v", m["msg"]) + } + if m["user"] != "user@test.com" { + t.Errorf("expected user=user@test.com, got %v", m["user"]) + } +} + +func TestSetOutput(t *testing.T) { + resetLogger(t) + buf1 := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + Setup(slog.LevelInfo, "json", buf1, errBuf) + + Info("before") + if buf1.Len() == 0 { + t.Error("expected output in buf1") + } + + buf2 := &bytes.Buffer{} + SetOutput(buf2) + buf1.Reset() + + Info("after") + if buf2.Len() == 0 { + t.Error("expected output in buf2 after SetOutput") + } + if buf1.Len() > 0 { + t.Error("buf1 should have no new output after SetOutput") + } +} + +func TestGetLevel(t *testing.T) { + resetLogger(t) + buf := &bytes.Buffer{} + Setup(slog.LevelWarn, "json", buf, buf) + + if GetLevel() != slog.LevelWarn { + t.Errorf("expected LevelWarn, got %v", GetLevel()) + } + + SetLevel(slog.LevelDebug) + if GetLevel() != slog.LevelDebug { + t.Errorf("expected LevelDebug, got %v", GetLevel()) + } +} diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go index 71b822c0..cae84e11 100644 --- a/pkg/middleware/basic_session.go +++ b/pkg/middleware/basic_session.go @@ -50,7 +50,7 @@ func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, pre session, err := getSession(validator, sessionGroups, req) if err != nil { - logger.Errorf("Error retrieving session from token in Authorization header: %v", err) + logger.ErrMsgf("error retrieving session from token in Authorization header: %v", err) } // Add the session to the scope if it was found @@ -75,12 +75,12 @@ func getBasicSession(validator basic.Validator, sessionGroups []string, req *htt } if validator.Validate(user, password) { - logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") + logger.LogAuth(user, req, logger.AuthSuccess, "authenticated via basic auth and HTpasswd File") return &sessionsapi.SessionState{User: user, Groups: sessionGroups}, nil } - logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") + logger.LogAuth(user, req, logger.AuthFailure, "invalid authentication via basic auth: not in Htpasswd File") return nil, nil } diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 790eb8b2..5156dac5 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -51,7 +51,7 @@ func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { session, err := j.getJwtSession(req) if err != nil { - logger.Errorf("Error retrieving session from token in Authorization header: %v", err) + logger.ErrMsgf("error retrieving session from token in Authorization header: %v", err) if j.denyInvalidJWTs { http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden) return diff --git a/pkg/middleware/middleware_suite_test.go b/pkg/middleware/middleware_suite_test.go index 0e3a2c9d..0694c4b3 100644 --- a/pkg/middleware/middleware_suite_test.go +++ b/pkg/middleware/middleware_suite_test.go @@ -1,6 +1,7 @@ package middleware import ( + "log/slog" "net/http" "testing" @@ -11,8 +12,7 @@ import ( ) func TestMiddlewareSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Middleware") diff --git a/pkg/middleware/request_logger.go b/pkg/middleware/request_logger.go index e6ed9f21..3ec5522c 100644 --- a/pkg/middleware/request_logger.go +++ b/pkg/middleware/request_logger.go @@ -29,7 +29,7 @@ func requestLogger(next http.Handler) http.Handler { scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. - logger.PrintReq( + logger.LogRequest( getUser(scope), scope.Upstream, req, diff --git a/pkg/middleware/request_logger_test.go b/pkg/middleware/request_logger_test.go index f481adf7..ae497580 100644 --- a/pkg/middleware/request_logger_test.go +++ b/pkg/middleware/request_logger_test.go @@ -2,6 +2,8 @@ package middleware import ( "bytes" + "encoding/json" + "log/slog" "net/http" "net/http/httptest" @@ -12,23 +14,32 @@ import ( . "github.com/onsi/gomega" ) -const RequestLoggingFormatWithoutTime = "{{.Client}} - {{.RequestID}} - {{.Username}} [TIMELESS] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}" - var _ = Describe("Request logger suite", func() { + type expectedFields struct { + User string + Client string + Host string + Method string + URI string + Protocol string + Upstream string + StatusCode float64 + Size float64 + RequestID string + } + type requestLoggerTableInput struct { - Format string - ExpectedLogMessage string - Path string - ExcludePaths []string - Upstream string - Session *sessions.SessionState + Expected *expectedFields // nil means no output expected + Path string + ExcludePaths []string + Upstream string + Session *sessions.SessionState } DescribeTable("when service a request", func(in *requestLoggerTableInput) { buf := bytes.NewBuffer(nil) - logger.SetOutput(buf) - logger.SetReqTemplate(in.Format) + logger.Setup(slog.LevelDebug, "json", buf, buf) logger.SetExcludePaths(in.ExcludePaths) req, err := http.NewRequest("GET", in.Path, nil) @@ -45,79 +56,119 @@ var _ = Describe("Request logger suite", func() { handler := NewRequestLogger()(testUpstreamHandler(in.Upstream)) handler.ServeHTTP(httptest.NewRecorder(), req) - Expect(buf.String()).To(Equal(in.ExpectedLogMessage)) + if in.Expected == nil { + Expect(buf.String()).To(BeEmpty()) + return + } + + var logEntry map[string]interface{} + Expect(json.Unmarshal(buf.Bytes(), &logEntry)).To(Succeed()) + + Expect(logEntry).To(HaveKeyWithValue("level", "INFO")) + Expect(logEntry).To(HaveKeyWithValue("msg", "request")) + Expect(logEntry).To(HaveKey("time")) + Expect(logEntry).To(HaveKeyWithValue("user", in.Expected.User)) + Expect(logEntry).To(HaveKeyWithValue("client", in.Expected.Client)) + Expect(logEntry).To(HaveKeyWithValue("host", in.Expected.Host)) + Expect(logEntry).To(HaveKeyWithValue("method", in.Expected.Method)) + Expect(logEntry).To(HaveKeyWithValue("uri", in.Expected.URI)) + Expect(logEntry).To(HaveKeyWithValue("protocol", in.Expected.Protocol)) + Expect(logEntry).To(HaveKeyWithValue("upstream", in.Expected.Upstream)) + Expect(logEntry).To(HaveKeyWithValue("status_code", in.Expected.StatusCode)) + Expect(logEntry).To(HaveKeyWithValue("response_size", in.Expected.Size)) + Expect(logEntry).To(HaveKeyWithValue("request_id", in.Expected.RequestID)) + Expect(logEntry).To(HaveKey("duration_s")) }, Entry("standard request", &requestLoggerTableInput{ - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - 11111111-2222-4333-8444-555555555555 - standard.user [TIMELESS] test-server GET standard \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/foo/bar", - ExcludePaths: []string{}, - Upstream: "standard", - Session: &sessions.SessionState{User: "standard.user"}, + Expected: &expectedFields{ + User: "standard.user", Client: "127.0.0.1", Host: "test-server", + Method: "GET", URI: "/foo/bar", Protocol: "HTTP/1.1", + Upstream: "standard", StatusCode: 200, Size: 4, + RequestID: "11111111-2222-4333-8444-555555555555", + }, + Path: "/foo/bar", + ExcludePaths: []string{}, + Upstream: "standard", + Session: &sessions.SessionState{User: "standard.user"}, }), Entry("with unrelated path excluded", &requestLoggerTableInput{ - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - 11111111-2222-4333-8444-555555555555 - unrelated.exclusion [TIMELESS] test-server GET unrelated \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/foo/bar", - ExcludePaths: []string{"/ping"}, - Upstream: "unrelated", - Session: &sessions.SessionState{User: "unrelated.exclusion"}, + Expected: &expectedFields{ + User: "unrelated.exclusion", Client: "127.0.0.1", Host: "test-server", + Method: "GET", URI: "/foo/bar", Protocol: "HTTP/1.1", + Upstream: "unrelated", StatusCode: 200, Size: 4, + RequestID: "11111111-2222-4333-8444-555555555555", + }, + Path: "/foo/bar", + ExcludePaths: []string{"/ping"}, + Upstream: "unrelated", + Session: &sessions.SessionState{User: "unrelated.exclusion"}, }), Entry("with path as the sole exclusion", &requestLoggerTableInput{ - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "", - Path: "/foo/bar", - ExcludePaths: []string{"/foo/bar"}, + Expected: nil, + Path: "/foo/bar", + ExcludePaths: []string{"/foo/bar"}, }), Entry("ping path", &requestLoggerTableInput{ - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "127.0.0.1 - 11111111-2222-4333-8444-555555555555 - mr.ping [TIMELESS] test-server GET - \"/ping\" HTTP/1.1 \"\" 200 4 0.000\n", - Path: "/ping", - ExcludePaths: []string{}, - Upstream: "", - Session: &sessions.SessionState{User: "mr.ping"}, + Expected: &expectedFields{ + User: "mr.ping", Client: "127.0.0.1", Host: "test-server", + Method: "GET", URI: "/ping", Protocol: "HTTP/1.1", + Upstream: "-", StatusCode: 200, Size: 4, + RequestID: "11111111-2222-4333-8444-555555555555", + }, + Path: "/ping", + ExcludePaths: []string{}, + Upstream: "", + Session: &sessions.SessionState{User: "mr.ping"}, }), Entry("ping path but excluded", &requestLoggerTableInput{ - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "", - Path: "/ping", - ExcludePaths: []string{"/ping"}, - Upstream: "", - Session: &sessions.SessionState{User: "mr.ping"}, + Expected: nil, + Path: "/ping", + ExcludePaths: []string{"/ping"}, + Upstream: "", + Session: &sessions.SessionState{User: "mr.ping"}, }), Entry("ping path and excluded in list", &requestLoggerTableInput{ - Format: RequestLoggingFormatWithoutTime, - ExpectedLogMessage: "", - Path: "/ping", - ExcludePaths: []string{"/foo/bar", "/ping"}, + Expected: nil, + Path: "/ping", + ExcludePaths: []string{"/foo/bar", "/ping"}, }), - Entry("custom format", &requestLoggerTableInput{ - Format: "{{.RequestMethod}} {{.Username}} {{.Upstream}}", - ExpectedLogMessage: "GET custom.format custom\n", - Path: "/foo/bar", - ExcludePaths: []string{""}, - Upstream: "custom", - Session: &sessions.SessionState{User: "custom.format"}, + Entry("request with no session", &requestLoggerTableInput{ + Expected: &expectedFields{ + User: "-", Client: "127.0.0.1", Host: "test-server", + Method: "GET", URI: "/foo/bar", Protocol: "HTTP/1.1", + Upstream: "custom", StatusCode: 200, Size: 4, + RequestID: "11111111-2222-4333-8444-555555555555", + }, + Path: "/foo/bar", + ExcludePaths: []string{""}, + Upstream: "custom", }), - Entry("custom format with unrelated exclusion", &requestLoggerTableInput{ - Format: "{{.RequestMethod}} {{.Username}} {{.Upstream}}", - ExpectedLogMessage: "GET custom.format custom\n", - Path: "/foo/bar", - ExcludePaths: []string{"/ping"}, - Upstream: "custom", - Session: &sessions.SessionState{User: "custom.format"}, + Entry("request with user session", &requestLoggerTableInput{ + Expected: &expectedFields{ + User: "custom.format", Client: "127.0.0.1", Host: "test-server", + Method: "GET", URI: "/foo/bar", Protocol: "HTTP/1.1", + Upstream: "custom", StatusCode: 200, Size: 4, + RequestID: "11111111-2222-4333-8444-555555555555", + }, + Path: "/foo/bar", + ExcludePaths: []string{"/ping"}, + Upstream: "custom", + Session: &sessions.SessionState{User: "custom.format"}, }), - Entry("custom format ping path", &requestLoggerTableInput{ - Format: "{{.RequestMethod}}", - ExpectedLogMessage: "GET\n", - Path: "/ping", - ExcludePaths: []string{""}, + Entry("request with empty upstream", &requestLoggerTableInput{ + Expected: &expectedFields{ + User: "-", Client: "127.0.0.1", Host: "test-server", + Method: "GET", URI: "/ping", Protocol: "HTTP/1.1", + Upstream: "-", StatusCode: 200, Size: 4, + RequestID: "11111111-2222-4333-8444-555555555555", + }, + Path: "/ping", + ExcludePaths: []string{""}, }), - Entry("custom format ping path excluded", &requestLoggerTableInput{ - Format: "{{.RequestMethod}}", - ExpectedLogMessage: "", - Path: "/ping", - ExcludePaths: []string{"/ping"}, + Entry("excluded path not matched", &requestLoggerTableInput{ + Expected: nil, + Path: "/ping", + ExcludePaths: []string{"/ping"}, }), ) }) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 53238f19..d321c5ac 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -119,10 +119,10 @@ func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { if err != nil && !errors.Is(err, http.ErrNoCookie) { // In the case when there was an error loading the session, // we should clear the session - logger.Errorf("Error loading cookied session: %v, removing session", err) + logger.ErrMsgf("error loading cookied session: %v, removing session", err) err = s.store.Clear(rw, req) if err != nil { - logger.Errorf("Error removing session: %v", err) + logger.ErrMsgf("error removing session: %v", err) } } @@ -186,7 +186,7 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req return } if err := session.ReleaseLock(req.Context()); err != nil { - logger.Errorf("unable to release lock: %v", err) + logger.ErrMsgf("unable to release lock: %v", err) } }() @@ -214,18 +214,20 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req } // We are holding the lock and the session needs a refresh - logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) + logger.Info("refreshing session", "user", session.User, "session_age", session.Age()) if err := s.refreshSession(rw, req, session); err != nil { - logger.Errorf("Unable to refresh session: %v", err) + // If a preemptive refresh fails, we still keep the session + // if validateSession succeeds. + logger.ErrMsgf("unable to refresh session: %v", err) // Check if this is a fatal error that indicates the session is revoked // or no longer valid at the provider level if isFatalRefreshError(err) { - logger.Printf("Fatal refresh error detected (session revoked or invalid), clearing session for user: %s", session.User) + logger.Warn("fatal refresh error detected; clearing session", "user", session.User, "error", err) // Clear the session from storage (Redis) and remove the cookie if err := s.store.Clear(rw, req); err != nil { - logger.Errorf("failed clearing session: %v", err) + logger.ErrMsgf("failed clearing session: %v", err) } // Return error immediately to force re-authentication @@ -265,7 +267,7 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R // Session not refreshed, nothing to persist. if !refreshed { - logger.Printf("Session not refreshed - User: %s; no refresh token available or provider returned false", session.User) + logger.Debug("session not refreshed", "user", session.User, "reason", "no refresh token available or provider returned false") return nil } @@ -276,7 +278,7 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R // Because the session was refreshed, make sure to save it err = s.store.Save(rw, req, session) if err != nil { - logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) + logger.LogAuth(session.Email, req, logger.AuthError, fmt.Sprintf("error saving session: %v", err)) return fmt.Errorf("error saving session: %v", err) } return nil diff --git a/pkg/providers/oidc/oidc_suite_test.go b/pkg/providers/oidc/oidc_suite_test.go index 84674dcd..ac83b3bb 100755 --- a/pkg/providers/oidc/oidc_suite_test.go +++ b/pkg/providers/oidc/oidc_suite_test.go @@ -1,6 +1,7 @@ package oidc import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,7 +10,7 @@ import ( ) func TestOIDCSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "OIDC") diff --git a/pkg/providers/oidc/provider.go b/pkg/providers/oidc/provider.go index 427f79a6..42a8219f 100644 --- a/pkg/providers/oidc/provider.go +++ b/pkg/providers/oidc/provider.go @@ -52,7 +52,7 @@ func NewProvider(ctx context.Context, issuerURL string, skipIssuerVerification b // (which uses discovery to get the URLs), so we'll do a quick check ourselves and if // we get the URLs, we'll just use the non-discovery path. - logger.Printf("Performing OIDC Discovery...") + logger.Info("performing OIDC Discovery...") var p providerJSON requestURL := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration" diff --git a/pkg/providers/util/util_suite_test.go b/pkg/providers/util/util_suite_test.go index b787cb23..ddefbe40 100644 --- a/pkg/providers/util/util_suite_test.go +++ b/pkg/providers/util/util_suite_test.go @@ -1,6 +1,7 @@ package util import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestProviderUtilSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Provider Utils") diff --git a/pkg/proxyhttp/http_suite_test.go b/pkg/proxyhttp/http_suite_test.go index 4241c7b6..703bc0e2 100644 --- a/pkg/proxyhttp/http_suite_test.go +++ b/pkg/proxyhttp/http_suite_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "log/slog" "net/http" "testing" @@ -22,8 +23,7 @@ var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource var transport *http.Transport func TestHTTPSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "HTTP") diff --git a/pkg/proxyhttp/server.go b/pkg/proxyhttp/server.go index 2982e1fc..455a9267 100644 --- a/pkg/proxyhttp/server.go +++ b/pkg/proxyhttp/server.go @@ -346,11 +346,11 @@ func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { } err = tc.SetKeepAlive(true) if err != nil { - logger.Errorf("Error setting Keep-Alive: %v", err) + logger.ErrMsgf("error setting Keep-Alive: %v", err) } err = tc.SetKeepAlivePeriod(3 * time.Minute) if err != nil { - logger.Printf("Error setting Keep-Alive period: %v", err) + logger.ErrMsgf("error setting Keep-Alive period: %v", err) } return tc, nil } diff --git a/pkg/requests/requests_suite_test.go b/pkg/requests/requests_suite_test.go index 929ae8bd..c65720d9 100644 --- a/pkg/requests/requests_suite_test.go +++ b/pkg/requests/requests_suite_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "log/slog" "net/http" "net/http/httptest" "testing" @@ -20,8 +21,7 @@ var ( ) func TestRequetsSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) log.SetOutput(GinkgoWriter) RegisterFailHandler(Fail) diff --git a/pkg/requests/util/util_suite_test.go b/pkg/requests/util/util_suite_test.go index 1cb07f37..698dab8d 100644 --- a/pkg/requests/util/util_suite_test.go +++ b/pkg/requests/util/util_suite_test.go @@ -1,6 +1,7 @@ package util_test import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -12,8 +13,7 @@ import ( // to prevent circular imports with the `logger` package which uses // this functionality func TestRequestUtilSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Request Utils") diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index a4da3734..bce39e50 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -186,7 +186,7 @@ func splitCookie(c *http.Cookie) []*http.Cookie { return []*http.Cookie{c} } - logger.Errorf("WARNING: Multiple cookies are required for this session as it exceeds the 4kb cookie limit. Please use server side session storage (eg. Redis) instead.") + logger.Warn("multiple cookies are required for this session as it exceeds the 4kb cookie limit. Please use server side session storage (eg. Redis) instead.") cookies := []*http.Cookie{} valueBytes := []byte(c.Value) diff --git a/pkg/sessions/cookie/session_store_test.go b/pkg/sessions/cookie/session_store_test.go index 5fc1ad78..7b6bee1d 100644 --- a/pkg/sessions/cookie/session_store_test.go +++ b/pkg/sessions/cookie/session_store_test.go @@ -2,6 +2,7 @@ package cookie import ( "fmt" + "log/slog" mathrand "math/rand" "net/http" "strings" @@ -18,8 +19,7 @@ import ( ) func TestSessionStore(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Cookie SessionStore") diff --git a/pkg/sessions/persistence/persistence_suite_test.go b/pkg/sessions/persistence/persistence_suite_test.go index 6edd912f..300ddf2c 100644 --- a/pkg/sessions/persistence/persistence_suite_test.go +++ b/pkg/sessions/persistence/persistence_suite_test.go @@ -1,6 +1,7 @@ package persistence import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestPersistenceSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Persistence") diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index 79f8f7d1..c43860cd 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -200,7 +200,7 @@ func setupTLSConfig(opts options.RedisStoreOptions, opt *redis.Options) error { if opts.CAPath != "" { rootCAs, err := x509.SystemCertPool() if err != nil { - logger.Errorf("failed to load system cert pool for redis connection, falling back to empty cert pool") + logger.ErrMsg("failed to load system cert pool for redis connection, falling back to empty cert pool") } if rootCAs == nil { rootCAs = x509.NewCertPool() @@ -212,7 +212,7 @@ func setupTLSConfig(opts options.RedisStoreOptions, opt *redis.Options) error { // Append our cert to the system pool if ok := rootCAs.AppendCertsFromPEM(certs); !ok { - logger.Errorf("no certs appended, using system certs only") + logger.ErrMsg("no certs appended, using system certs only") } if opt.TLSConfig == nil { diff --git a/pkg/sessions/redis/redis_test.go b/pkg/sessions/redis/redis_test.go index c38553cb..b19a7c4c 100644 --- a/pkg/sessions/redis/redis_test.go +++ b/pkg/sessions/redis/redis_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/pem" "log" + "log/slog" "os" "testing" @@ -32,8 +33,7 @@ var ( ) func TestRedis(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)} redisLogger.SetOutput(GinkgoWriter) diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 0db45f78..27b311c0 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -3,6 +3,7 @@ package sessions_test import ( "crypto/rand" "encoding/base64" + "log/slog" "testing" "time" @@ -17,8 +18,7 @@ import ( ) func TestSessionStore(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "SessionStore") diff --git a/pkg/upstream/proxy.go b/pkg/upstream/proxy.go index 857395ec..7ce8b79e 100644 --- a/pkg/upstream/proxy.go +++ b/pkg/upstream/proxy.go @@ -77,19 +77,19 @@ func (m *multiUpstreamProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request // registerStaticResponseHandler registers a static response handler with at the given path. func (m *multiUpstreamProxy) registerStaticResponseHandler(upstream options.Upstream, writer pagewriter.Writer) error { - logger.Printf("mapping path %q => static response %d", upstream.Path, ptr.Deref(upstream.StaticCode, options.DefaultUpstreamStaticCode)) + logger.Infof("mapping path %q => static response %d", upstream.Path, ptr.Deref(upstream.StaticCode, options.DefaultUpstreamStaticCode)) return m.registerHandler(upstream, newStaticResponseHandler(upstream.ID, upstream.StaticCode), writer) } // registerFileServer registers a new fileServer based on the configuration given. func (m *multiUpstreamProxy) registerFileServer(upstream options.Upstream, u *url.URL, writer pagewriter.Writer) error { - logger.Printf("mapping path %q => file system %q", upstream.Path, u.Path) + logger.Infof("mapping path %q => file system %q", upstream.Path, u.Path) return m.registerHandler(upstream, newFileServer(upstream, u.Path), writer) } // registerHTTPUpstreamProxy registers a new httpUpstreamProxy based on the configuration given. func (m *multiUpstreamProxy) registerHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *options.SignatureData, writer pagewriter.Writer) error { - logger.Printf("mapping path %q => upstream %q", upstream.Path, upstream.URI) + logger.Infof("mapping path %q => upstream %q", upstream.Path, upstream.URI) return m.registerHandler(upstream, newHTTPUpstreamProxy(upstream, u, sigData, writer.ProxyErrorHandler), writer) } diff --git a/pkg/upstream/rewrite.go b/pkg/upstream/rewrite.go index 343c740b..219c21b9 100644 --- a/pkg/upstream/rewrite.go +++ b/pkg/upstream/rewrite.go @@ -27,7 +27,7 @@ func rewritePath(rewriteRegExp *regexp.Regexp, rewriteTarget string, writer page return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { reqURL, err := url.ParseRequestURI(req.RequestURI) if err != nil { - logger.Errorf("could not parse request URI: %v", err) + logger.ErrMsgf("could not parse request URI: %v", err) writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{ Status: http.StatusInternalServerError, RequestID: middleware.GetRequestScope(req).RequestID, @@ -40,7 +40,7 @@ func rewritePath(rewriteRegExp *regexp.Regexp, rewriteTarget string, writer page newURI := rewriteRegExp.ReplaceAllString(reqURL.Path, rewriteTarget) reqURL.Path, reqURL.RawQuery, err = splitPathAndQuery(reqURL.Query(), newURI) if err != nil { - logger.Errorf("could not parse rewrite URI: %v", err) + logger.ErrMsgf("could not parse rewrite URI: %v", err) writer.WriteErrorPage(rw, pagewriter.ErrorPageOpts{ Status: http.StatusInternalServerError, RequestID: middleware.GetRequestScope(req).RequestID, diff --git a/pkg/upstream/static.go b/pkg/upstream/static.go index 6f002b8f..07683dad 100644 --- a/pkg/upstream/static.go +++ b/pkg/upstream/static.go @@ -35,6 +35,6 @@ func (s *staticResponseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Requ rw.WriteHeader(s.code) _, err := fmt.Fprintf(rw, "Authenticated") if err != nil { - logger.Errorf("Error writing static response: %v", err) + logger.ErrMsgf("error writing static response: %v", err) } } diff --git a/pkg/upstream/upstream_suite_test.go b/pkg/upstream/upstream_suite_test.go index 0ca23941..c622944f 100644 --- a/pkg/upstream/upstream_suite_test.go +++ b/pkg/upstream/upstream_suite_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "log/slog" "net" "net/http" "net/http/httptest" @@ -28,8 +29,7 @@ var ( ) func TestUpstreamSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) log.SetOutput(GinkgoWriter) RegisterFailHandler(Fail) diff --git a/pkg/validation/logging.go b/pkg/validation/logging.go index c291405c..f6cce61d 100644 --- a/pkg/validation/logging.go +++ b/pkg/validation/logging.go @@ -1,16 +1,49 @@ package validation import ( + "fmt" + "log/slog" "os" + "strings" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "gopkg.in/natefinch/lumberjack.v2" ) +// parseLogLevel converts a string log level to slog.Level. +func parseLogLevel(s string) (slog.Level, error) { + switch strings.ToLower(s) { + case "debug": + return slog.LevelDebug, nil + case "info": + return slog.LevelInfo, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + default: + return slog.LevelInfo, fmt.Errorf("invalid log level %q: must be one of debug, info, warn, error", s) + } +} + // configureLogger is responsible for configuring the logger based on the options given func configureLogger(o options.Logging, msgs []string) []string { - // Setup the log file + // Parse and validate log level + level, err := parseLogLevel(o.Level) + if err != nil { + msgs = append(msgs, err.Error()) + return msgs + } + + // Validate log format + format := strings.ToLower(o.Format) + if format != "json" && format != "text" { + msgs = append(msgs, fmt.Sprintf("invalid log format %q: must be one of json, text", o.Format)) + return msgs + } + + // Determine output writers if len(o.File.Filename) > 0 { // Validate that the file/dir can be written file, err := os.OpenFile(o.File.Filename, os.O_WRONLY|os.O_CREATE, 0600) @@ -23,9 +56,14 @@ func configureLogger(o options.Logging, msgs []string) []string { if err != nil { return append(msgs, "error closing the log file: "+o.File.Filename) } + } - logger.Printf("Redirecting logging to file: %s", o.File.Filename) + // Setup writers + var stdWriter, errWriter *os.File + stdWriter = os.Stdout + errWriter = os.Stderr + if len(o.File.Filename) > 0 { logWriter := &lumberjack.Logger{ Filename: o.File.Filename, MaxSize: o.File.MaxSize, // megabytes @@ -35,28 +73,48 @@ func configureLogger(o options.Logging, msgs []string) []string { Compress: o.File.Compress, } - logger.SetOutput(logWriter) + // Setup with lumberjack writer + errW := errWriter + if o.ErrToInfo { + logger.Setup(level, format, logWriter, logWriter) + } else { + logger.Setup(level, format, logWriter, errW) + } + + logger.Info("logging redirected to file", "filename", o.File.Filename) + } else { + // Setup with stdout/stderr + if o.ErrToInfo { + logger.Setup(level, format, stdWriter, stdWriter) + } else { + logger.Setup(level, format, stdWriter, errWriter) + } } - // Supply a sanity warning to the logger if all logging is disabled - if !o.StandardEnabled && !o.AuthEnabled && !o.RequestEnabled { - logger.Error("Warning: Logging disabled. No further logs will be shown.") - } - - // Pass configuration values to the standard logger - logger.SetStandardEnabled(o.StandardEnabled) - logger.SetErrToInfo(o.ErrToInfo) - logger.SetAuthEnabled(o.AuthEnabled) - logger.SetReqEnabled(o.RequestEnabled) + logger.SetLocalTime(o.LocalTime) logger.SetStandardTemplate(o.StandardFormat) logger.SetAuthTemplate(o.AuthFormat) logger.SetReqTemplate(o.RequestFormat) + logger.SetErrToInfo(o.ErrToInfo) + logger.SetStandardEnabled(true) - logger.SetExcludePaths(o.ExcludePaths) - - if !o.LocalTime { - logger.SetFlags(logger.Flags() | logger.LUTC) + // Supply a sanity warning to the logger if all logging is disabled + if !o.StandardEnabled && !o.AuthEnabled && !o.RequestEnabled { + logger.Warn("logging disabled: standard, auth, and request logging are all off") } + logger.SetStandardEnabled(o.StandardEnabled) + + // Configure categorical logging + logger.SetAuthEnabled(o.AuthEnabled) + logger.SetReqEnabled(o.RequestEnabled) + + // Configure exclude paths + excludePaths := o.ExcludePaths + if o.SilencePing { + excludePaths = append(excludePaths, "/ping", "/ready") + } + logger.SetExcludePaths(excludePaths) + return msgs } diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 13ce2e0b..6f01682d 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -75,7 +75,7 @@ func Validate(o *options.Options) error { redirectURL, msgs = parseURL(o.RawRedirectURL, "redirect", msgs) o.SetRedirectURL(redirectURL) if o.RawRedirectURL == "" && !o.Cookie.Secure && !o.ReverseProxy { - logger.Print("WARNING: no explicit redirect URL: redirects will default to insecure HTTP") + logger.Warn("no explicit redirect URL: redirects will default to insecure HTTP") } msgs = append(msgs, validateUpstreams(o.UpstreamServers)...) @@ -108,7 +108,7 @@ func parseSignatureKey(o *options.Options, msgs []string) []string { return msgs } - logger.Print("WARNING: `--signature-key` is deprecated. It will be removed in a future release") + logger.Warn("`--signature-key` is deprecated. It will be removed in a future release") components := strings.Split(o.SignatureKey, ":") if len(components) != 2 { diff --git a/pkg/validation/validation_suite_test.go b/pkg/validation/validation_suite_test.go index e86f36f8..34dd8aa9 100644 --- a/pkg/validation/validation_suite_test.go +++ b/pkg/validation/validation_suite_test.go @@ -1,6 +1,7 @@ package validation import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestValidationSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Validation Suite") diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index b609e969..16148545 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -25,19 +25,19 @@ func WatchFileForUpdates(filename string, done <-chan bool, action func()) error for { select { case <-done: - logger.Printf("shutting down watcher for: %s", filename) + logger.Infof("shutting down watcher for: %s", filename) return case event := <-watcher.Events: filterEvent(watcher, event, filename, action) case err = <-watcher.Errors: - logger.Errorf("error watching '%s': %s", filename, err) + logger.ErrMsgf("error watching '%s': %s", filename, err) } } }() if err := watcher.Add(filename); err != nil { return fmt.Errorf("failed to add '%s' to watcher: %v", filename, err) } - logger.Printf("watching '%s' for updates", filename) + logger.Infof("watching '%s' for updates", filename) return nil } @@ -51,11 +51,11 @@ func filterEvent(watcher *fsnotify.Watcher, event fsnotify.Event, filename strin // In Kubernetes the file path is a symlink, so we should take action // when the ConfigMap/Secret is replaced. case event.Op&fsnotify.Remove != 0: - logger.Printf("watching interrupted on event: %s", event) + logger.Infof("watching interrupted on event: %s", event) WaitForReplacement(filename, event.Op, watcher) action() case event.Op&(fsnotify.Create|fsnotify.Write) != 0: - logger.Printf("reloading after event: %s", event) + logger.Infof("reloading after event: %s", event) action() } } @@ -72,7 +72,7 @@ func WaitForReplacement(filename string, op fsnotify.Op, watcher *fsnotify.Watch for { if _, err := os.Stat(filename); err == nil { if err := watcher.Add(filename); err == nil { - logger.Printf("watching resumed for '%s'", filename) + logger.Infof("watching resumed for '%s'", filename) return } } diff --git a/providers/azure.go b/providers/azure.go index ff22ebc3..1ec1ed13 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -92,7 +92,7 @@ func NewAzureProvider(p *ProviderData, opts options.AzureOptions) *AzureProvider azureV2GraphScope := fmt.Sprintf("https://%s/.default", p.ProfileURL.Host) if strings.Contains(p.Scope, " groups") { - logger.Print("WARNING: `groups` scope is not an accepted scope when using Azure OAuth V2 endpoint. Removing it from the scope list") + logger.Warn("`groups` scope is not an accepted scope when using Azure OAuth V2 endpoint. Removing it from the scope list") p.Scope = strings.ReplaceAll(p.Scope, " groups", "") } @@ -102,7 +102,7 @@ func NewAzureProvider(p *ProviderData, opts options.AzureOptions) *AzureProvider } if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { - logger.Print("WARNING: `--resource` option has no effect when using the Azure OAuth V2 endpoint.") + logger.Warn("`--resource` option has no effect when using the Azure OAuth V2 endpoint.") } } @@ -197,7 +197,7 @@ func (p *AzureProvider) EnrichSession(ctx context.Context, session *sessions.Ses err := p.extractClaimsIntoSession(ctx, session) if err != nil { - logger.Printf("unable to get email and/or groups claims from token: %v", err) + logger.Infof("unable to get email and/or groups claims from token: %v", err) } if session.Email == "" { @@ -288,7 +288,7 @@ func (p *AzureProvider) verifySessionToken(ctx context.Context, session *session if session.IDToken != "" { if _, err := p.Verifier.Verify(ctx, session.IDToken); err != nil { - logger.Printf("unable to verify ID token, fallback to access token: %v", err) + logger.Infof("unable to verify ID token, fallback to access token: %v", err) if _, err = p.Verifier.Verify(ctx, session.AccessToken); err != nil { return fmt.Errorf("unable to verify access token: %v", err) } @@ -353,7 +353,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess err = p.extractClaimsIntoSession(ctx, s) if err != nil { - logger.Printf("unable to get email and/or groups claims from token: %v", err) + logger.Infof("unable to get email and/or groups claims from token: %v", err) } return nil @@ -445,7 +445,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { if err != nil || email == "" { email, err = json.Get("userPrincipalName").String() if err != nil { - logger.Errorf("unable to find userPrincipalName: %s", err) + logger.ErrMsgf("unable to find userPrincipalName: %s", err) return "", err } } diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 75cd2926..9dad5c1a 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -117,7 +117,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses Do(). UnmarshalInto(&emails) if err != nil { - logger.Errorf("failed making request: %v", err) + logger.ErrMsgf("failed making request: %v", err) return "", err } @@ -133,7 +133,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses Do(). UnmarshalInto(&teams) if err != nil { - logger.Errorf("failed requesting teams membership: %v", err) + logger.ErrMsgf("failed requesting teams membership: %v", err) return "", err } var found = false @@ -144,7 +144,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses } } if !found { - logger.Error("team membership test failed, access denied") + logger.ErrMsg("team membership test failed, access denied") return "", nil } } @@ -163,7 +163,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses Do(). UnmarshalInto(&repositories) if err != nil { - logger.Errorf("failed checking repository access: %v", err) + logger.ErrMsgf("failed checking repository access: %v", err) return "", err } @@ -175,7 +175,7 @@ func (p *BitbucketProvider) GetEmailAddress(ctx context.Context, s *sessions.Ses } } if !found { - logger.Error("repository access test failed, access denied") + logger.ErrMsg("repository access test failed, access denied") return "", nil } } diff --git a/providers/cidaas.go b/providers/cidaas.go index 1d526dfc..7fa3c75d 100644 --- a/providers/cidaas.go +++ b/providers/cidaas.go @@ -72,7 +72,7 @@ func (p *CIDAASProvider) EnrichSession(ctx context.Context, s *sessions.SessionS // Try to get missing emails or groups from a profileURL if err := p.enrichFromUserinfoEndpoint(ctx, s); err != nil { - logger.Errorf("Warning: Profile URL request failed: %s", err) + logger.Warn("profile URL request failed", "error", err) } // If a mandatory email wasn't set, error at this point. diff --git a/providers/github.go b/providers/github.go index 09f10083..ad480d3c 100644 --- a/providers/github.go +++ b/providers/github.go @@ -165,13 +165,13 @@ func (p *GitHubProvider) hasOrg(s *sessions.SessionState) error { presentOrgs := make([]string, 0, len(orgs)) for _, org := range orgs { if p.Org == org { - logger.Printf("Found Github Organization:%q", org) + logger.Info("found Github organization", "org", org) return nil } presentOrgs = append(presentOrgs, org) } - logger.Printf("Missing Organization:%q in %v", p.Org, presentOrgs) + logger.Info("missing organization", "required_org", p.Org, "present_orgs", presentOrgs) return errors.New("user is missing required organization") } @@ -204,7 +204,7 @@ func (p *GitHubProvider) hasOrgAndTeam(s *sessions.SessionState) error { teams := strings.Split(p.Team, ",") for _, team := range teams { if strings.EqualFold(strings.TrimSpace(team), ot.Team) { - logger.Printf("Found Github Organization/Team:%q/%q", ot.Org, ot.Team) + logger.Info("found Github organization/team", "org", ot.Org, "team", ot.Team) return nil } } @@ -213,11 +213,11 @@ func (p *GitHubProvider) hasOrgAndTeam(s *sessions.SessionState) error { } if hasOrg { - logger.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) + logger.Info("missing team from org", "team", p.Team, "org", p.Org, "present_teams", presentTeams) return errors.New("user is missing required team") } - logger.Printf("Missing Organization:%q in %#v", p.Org, maps.Keys(presentOrgs)) + logger.Info("missing organization", "org", p.Org, "present_orgs", maps.Keys(presentOrgs)) return errors.New("user is missing required organization") } @@ -236,19 +236,19 @@ func (p *GitHubProvider) hasTeam(s *sessions.SessionState) error { allowedTeams := strings.Split(p.Team, ",") for _, team := range allowedTeams { if !strings.Contains(team, orgTeamSeparator) { - logger.Printf("Please use fully qualified team names (org:team-slug) if you omit the organisation. Current Team name: %s", team) + logger.Warn("please use fully qualified team names (org:team-slug) if you omit the organisation", "team", team) return errors.New("team name is invalid") } if strings.EqualFold(strings.TrimSpace(team), ot) { - logger.Printf("Found Github Organization/Team:%s", ot) + logger.Info("found Github organization/team", "team", ot) return nil } } presentTeams = append(presentTeams, ot) } - logger.Printf("Missing Team:%q in teams: %v", p.Team, presentTeams) + logger.Info("missing team", "team", p.Team, "present_teams", presentTeams) return errors.New("user is missing required team") } @@ -329,7 +329,7 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok result.StatusCode(), endpoint.String(), result.Body()) } - logger.Printf("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) + logger.Infof("got %d from %q %s", result.StatusCode(), endpoint.String(), result.Body()) return true, nil } @@ -497,10 +497,10 @@ func (p *GitHubProvider) getOrgs(ctx context.Context, s *sessions.SessionState) var orgName string if len(org.Login) > 0 { orgName = org.Login - logger.Printf("Member of Github Organization: %q", orgName) + logger.Info("member of Github organization", "org", orgName) } else { orgName = org.Name - logger.Printf("Member of Gitea Organization: %q", orgName) + logger.Info("member of Gitea organization", "org", orgName) } s.Groups = append(s.Groups, orgName) @@ -551,11 +551,11 @@ func (p *GitHubProvider) getTeams(ctx context.Context, s *sessions.SessionState) if len(team.Org.Login) > 0 { orgName = team.Org.Login teamName = team.Slug - logger.Printf("Member of Github Organization/Team: %q/%q", orgName, teamName) + logger.Info("member of Github organization/team", "org", orgName, "team", teamName) } else { orgName = team.Org.Name teamName = team.Name - logger.Printf("Member of Gitea Organization/Team: %q/%q", orgName, teamName) + logger.Info("member of Gitea organization/team", "org", orgName, "team", teamName) } s.Groups = append(s.Groups, fmt.Sprintf("%s%s%s", orgName, orgTeamSeparator, teamName)) diff --git a/providers/gitlab.go b/providers/gitlab.go index 1f5cfa16..f3f46c68 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -185,12 +185,12 @@ func (p *GitLabProvider) addProjectsToSession(ctx context.Context, s *sessions.S for _, project := range p.allowedProjects { projectInfo, err := p.getProjectInfo(ctx, s, project.Name) if err != nil { - logger.Errorf("Warning: project info request failed: %v", err) + logger.Warn("project info request failed", "error", err) continue } if projectInfo.Archived { - logger.Errorf("Warning: project %s is archived", project.Name) + logger.Warn("project is archived", "project", project.Name) continue } @@ -200,17 +200,17 @@ func (p *GitLabProvider) addProjectsToSession(ctx context.Context, s *sessions.S perms = projectInfo.Permissions.GroupAccess // group project access is not set for this user then we give up if perms == nil { - logger.Errorf("Warning: user %q has no project level access to %s", - s.Email, project.Name) + logger.Warn("user has no project level access", + "user", s.Email, "project", project.Name) continue } } if perms.AccessLevel < project.AccessLevel { - logger.Errorf( - "Warning: user %q does not have the minimum required access level for project %q", - s.Email, - project.Name, + logger.Warn( + "user does not have the minimum required access level for project", + "user", s.Email, + "project", project.Name, ) continue } diff --git a/providers/google.go b/providers/google.go index d28fb62d..80239466 100644 --- a/providers/google.go +++ b/providers/google.go @@ -284,7 +284,7 @@ func (p *GoogleProvider) populateAllGroups(adminService *admin.Service) func(s * // Get all groups of the user groups, err := getUserGroups(adminService, s.Email) if err != nil { - logger.Errorf("Failed to get user groups for %s: %v", s.Email, err) + logger.ErrMsgf("failed to get user groups for %s: %v", s.Email, err) s.Groups = []string{} return true // Allow access even if we can't get groups } @@ -311,24 +311,24 @@ func getOauth2TokenSource(ctx context.Context, opts options.GoogleOptions, scope Subject: opts.AdminEmail, }) if err != nil { - logger.Fatal("failed to fetch application default credentials: ", err) + logger.FatalMsg("failed to fetch application default credentials", "error", err) } return ts } credentialsReader, err := os.Open(opts.ServiceAccountJSON) if err != nil { - logger.Fatal("couldn't open Google credentials file: ", err) + logger.FatalMsg("couldn't open Google credentials file", "error", err) } data, err := io.ReadAll(credentialsReader) if err != nil { - logger.Fatal("can't read Google credentials file:", err) + logger.FatalMsg("can't read Google credentials file", "error", err) } conf, err := google.JWTConfigFromJSON(data, scope) if err != nil { - logger.Fatal("can't load Google credentials file:", err) + logger.FatalMsg("can't load Google credentials file", "error", err) } conf.Subject = opts.AdminEmail @@ -357,24 +357,24 @@ func getAdminService(opts options.GoogleOptions) *admin.Service { retrieveErrBody := map[string]interface{}{} if err := json.Unmarshal(retrieveErr.Body, &retrieveErrBody); err != nil { - logger.Fatal("error unmarshalling retrieveErr body:", err) + logger.FatalMsg("error unmarshalling retrieveErr body", "error", err) } if retrieveErrBody["error"] == "unauthorized_client" && retrieveErrBody["error_description"] == "Client is unauthorized to retrieve access tokens using this method, or client not authorized for any of the scopes requested." { continue } - logger.Fatal("error retrieving token:", err) + logger.FatalMsg("error retrieving token", "error", err) } } if client == nil { - logger.Fatal("error: google credentials do not have enough permissions to access admin API scope") + logger.FatalMsg("error: google credentials do not have enough permissions to access admin API scope") } adminService, err := admin.NewService(ctx, option.WithHTTPClient(client)) if err != nil { - logger.Fatal(err) + logger.FatalMsg("failed to create admin service", "error", err) } return adminService } @@ -385,26 +385,26 @@ func getTargetPrincipal(ctx context.Context, opts options.GoogleOptions) (target if targetPrincipal != "" { return targetPrincipal } - logger.Print("INFO: no target principal set, trying to automatically determine one instead.") + logger.Info("no target principal set, trying to automatically determine one instead.") credential, err := google.FindDefaultCredentials(ctx) if err != nil { - logger.Fatal("failed to fetch application default credentials: ", err) + logger.FatalMsg("failed to fetch application default credentials", "error", err) } content := map[string]interface{}{} err = json.Unmarshal(credential.JSON, &content) switch { case err != nil && !metadata.OnGCE(): - logger.Fatal("unable to unmarshal Application Default Credentials JSON", err) + logger.FatalMsg("unable to unmarshal Application Default Credentials JSON", "error", err) case content["client_email"] != nil: targetPrincipal = fmt.Sprintf("%v", content["client_email"]) case metadata.OnGCE(): targetPrincipal, err = metadata.EmailWithContext(ctx, "") if err != nil { - logger.Fatal("error while calling the GCE metadata server", err) + logger.FatalMsg("error while calling the GCE metadata server", "error", err) } default: - logger.Fatal("unable to determine Application Default Credentials TargetPrincipal, try overriding with --target-principal instead.") + logger.FatalMsg("unable to determine Application Default Credentials TargetPrincipal, try overriding with --target-principal instead.") } return targetPrincipal } @@ -476,7 +476,7 @@ func userInGroup(service *admin.Service, group string, email string) bool { gerr, ok := err.(*googleapi.Error) switch { case ok && gerr.Code == 404: - logger.Errorf("error checking membership in group %s: group does not exist", group) + logger.ErrMsg("error checking membership in group: group does not exist", "group", group) case ok && gerr.Code == 400: // It is possible for Members.HasMember to return false even if the email is a group member. // One case that can cause this is if the user email is from a different domain than the group, @@ -485,7 +485,7 @@ func userInGroup(service *admin.Service, group string, email string) bool { req := service.Members.Get(group, email) r, err := req.Do() if err != nil { - logger.Errorf("error using get API to check member %s of google group %s: user not in the group", email, group) + logger.ErrMsg("error using get API to check member of google group: user not in the group", "email", email, "group", group) return false } @@ -495,7 +495,7 @@ func userInGroup(service *admin.Service, group string, email string) bool { return true } default: - logger.Errorf("error checking group membership: %v", err) + logger.ErrMsgf("error checking group membership: %v", err) } return false } diff --git a/providers/internal_util.go b/providers/internal_util.go index 49e1fd94..5961f73c 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -24,14 +24,14 @@ func stripToken(endpoint string) string { func stripParam(param, endpoint string) string { u, err := url.Parse(endpoint) if err != nil { - logger.Errorf("error attempting to strip %s: %s", param, err) + logger.ErrMsgf("error attempting to strip %s: %s", param, err) return endpoint } if u.RawQuery != "" { values, err := url.ParseQuery(u.RawQuery) if err != nil { - logger.Errorf("error attempting to strip %s: %s", param, err) + logger.ErrMsgf("error attempting to strip %s: %s", param, err) return u.String() } @@ -66,17 +66,17 @@ func validateToken(ctx context.Context, p Provider, accessToken string, header h WithHeaders(header). Do() if result.Error() != nil { - logger.Errorf("GET %s", stripToken(endpoint)) - logger.Errorf("token validation request failed: %s", result.Error()) + logger.ErrMsgf("GET %s", stripToken(endpoint)) + logger.ErrMsgf("token validation request failed: %s", result.Error()) return false } - logger.Printf("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) + logger.Infof("%d GET %s %s", result.StatusCode(), stripToken(endpoint), result.Body()) if result.StatusCode() == 200 { return true } - logger.Errorf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) + logger.ErrMsgf("token validation request failed: status %d - %s", result.StatusCode(), result.Body()) return false } diff --git a/providers/keycloak.go b/providers/keycloak.go index e36068c4..f7156547 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -79,7 +79,7 @@ func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.Sessio Do(). UnmarshalSimpleJSON() if err != nil { - logger.Errorf("failed making request %v", err) + logger.ErrMsgf("failed making request %v", err) return err } diff --git a/providers/ms_entra_id.go b/providers/ms_entra_id.go index 97a18e48..80de5573 100644 --- a/providers/ms_entra_id.go +++ b/providers/ms_entra_id.go @@ -70,7 +70,7 @@ func (p *MicrosoftEntraIDProvider) EnrichSession(ctx context.Context, session *s } if hasGroupOverage { - logger.Printf("entra overage found, reading groups from Graph API") + logger.Info("entra overage found, reading groups from Graph API") if err = p.addGraphGroupsToSession(ctx, session); err != nil { return fmt.Errorf("unable to enrich session: %v", err) } @@ -83,17 +83,17 @@ func (p *MicrosoftEntraIDProvider) EnrichSession(ctx context.Context, session *s func (p *MicrosoftEntraIDProvider) ValidateSession(ctx context.Context, session *sessions.SessionState) bool { tenant, err := p.getTenantFromToken(session) if err != nil { - logger.Errorf("unable to retrieve entra tenant from token: %v", err) + logger.ErrMsgf("unable to retrieve entra tenant from token: %v", err) return false } if len(p.multiTenantAllowedTenants) > 0 { tenantAllowed := p.checkTenantMatchesTenantList(tenant, p.multiTenantAllowedTenants) if !tenantAllowed { - logger.Printf("entra: tenant %s is not specified in the list of allowed tenants", tenant) + logger.Info("entra: tenant not in allowed list", "tenant", tenant) return false } - logger.Printf("entra: tenant %s is allowed", tenant) + logger.Info("entra: tenant is allowed", "tenant", tenant) } return p.OIDCProvider.ValidateSession(ctx, session) @@ -247,7 +247,7 @@ func (p *MicrosoftEntraIDProvider) addGraphGroupsToSession(ctx context.Context, UnmarshalSimpleJSON() if err != nil { - logger.Errorf("invalid response from microsoft graph, no groups added to session: %v", err) + logger.ErrMsgf("invalid response from microsoft graph, no groups added to session: %v", err) return nil } reqGroups := response.Get("value").MustArray() diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 6c791d27..fc9bd144 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -49,7 +49,7 @@ func (p *NextcloudProvider) EnrichSession(ctx context.Context, s *sessions.Sessi Do(). UnmarshalSimpleJSON() if err != nil { - logger.Errorf("failed making request %v", err) + logger.ErrMsgf("failed making request %v", err) return err } diff --git a/providers/oidc.go b/providers/oidc.go index aa022f63..cde7de73 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -120,14 +120,14 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS if s.Refreshed { validateEndpointAvailable := p.Data().ValidateURL != nil && p.Data().ValidateURL.String() != "" if validateEndpointAvailable && !validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) { - logger.Errorf("access_token validation failed") + logger.ErrMsg("access_token validation failed") return false } return true } if _, err := p.Verifier.Verify(ctx, s.IDToken); err != nil { - logger.Errorf("id_token verification failed: %v", err) + logger.ErrMsgf("id_token verification failed: %v", err) return false } @@ -136,7 +136,7 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS } if err := p.checkNonce(s); err != nil { - logger.Errorf("nonce verification failed: %v", err) + logger.ErrMsgf("nonce verification failed: %v", err) return false } diff --git a/providers/provider_data.go b/providers/provider_data.go index 80bd77ae..27d9668e 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -89,7 +89,7 @@ func (p *ProviderData) GetClientSecret() (clientSecret string, err error) { // Getting ClientSecret can fail in runtime so we need to report it without returning the file name to the user fileClientSecret, err := os.ReadFile(p.ClientSecretFile) if err != nil { - logger.Errorf("error reading client secret file %s: %s", p.ClientSecretFile, err) + logger.ErrMsgf("error reading client secret file %s: %s", p.ClientSecretFile, err) return "", errors.New("could not read client secret file") } return string(fileClientSecret), nil @@ -331,7 +331,7 @@ func (p *ProviderData) extractAdditionalClaims(extractor util.ClaimExtractor, ss for _, claim := range p.AdditionalClaims { value, exists, err := extractor.GetClaim(claim) if err != nil { - logger.Printf("error extracting additional claim %q: %v", claim, err) + logger.Warn("error extracting additional claim", "claim", claim, "error", err) continue } if exists { diff --git a/providers/providers.go b/providers/providers.go index f87d26a2..24c7c14e 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -154,7 +154,7 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData, // Set PKCE enabled or disabled based on discovery and force options p.CodeChallengeMethod = parseCodeChallengeMethod(providerConfig) if len(p.SupportedCodeChallengeMethods) != 0 && p.CodeChallengeMethod == "" { - logger.Printf("Warning: Your provider supports PKCE methods %+q, but you have not enabled one with --code-challenge-method", p.SupportedCodeChallengeMethods) + logger.Warn("provider supports PKCE but no code-challenge-method is enabled", "supported_methods", p.SupportedCodeChallengeMethods) } if providerConfig.OIDCConfig.UserIDClaim == "" { diff --git a/providers/providers_suite_test.go b/providers/providers_suite_test.go index 2b3fa04a..d443602c 100644 --- a/providers/providers_suite_test.go +++ b/providers/providers_suite_test.go @@ -1,6 +1,7 @@ package providers_test import ( + "log/slog" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -9,8 +10,7 @@ import ( ) func TestProviderSuite(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) + logger.Setup(slog.LevelDebug, "text", GinkgoWriter, GinkgoWriter) RegisterFailHandler(Fail) RunSpecs(t, "Providers") diff --git a/providers/srht.go b/providers/srht.go index e927629e..8068cfc9 100644 --- a/providers/srht.go +++ b/providers/srht.go @@ -83,7 +83,7 @@ func (p *SourceHutProvider) EnrichSession(ctx context.Context, s *sessions.Sessi Do(). UnmarshalSimpleJSON() if err != nil { - logger.Errorf("failed making request %v", err) + logger.ErrMsgf("failed making request %v", err) return err } diff --git a/validator.go b/validator.go index d03157a5..86021415 100644 --- a/validator.go +++ b/validator.go @@ -26,7 +26,7 @@ func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { m := make(map[string]bool) atomic.StorePointer(&um.m, unsafe.Pointer(&m)) // #nosec G103 if usersFile != "" { - logger.Printf("using authenticated emails file %s", usersFile) + logger.Info("using authenticated emails file", "path", usersFile) watcher.WatchFileForUpdates(usersFile, done, func() { um.LoadAuthenticatedEmailsFile() onUpdate() @@ -48,12 +48,12 @@ func (um *UserMap) IsValid(email string) (result bool) { func (um *UserMap) LoadAuthenticatedEmailsFile() { r, err := os.Open(um.usersFile) if err != nil { - logger.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) + logger.FatalMsgf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) } defer func(c io.Closer) { cerr := c.Close() if cerr != nil { - logger.Fatalf("Error closing authenticated emails file: %s", cerr) + logger.FatalMsgf("error closing authenticated emails file: %s", cerr) } }(r) csvReader := csv.NewReader(r) @@ -62,7 +62,7 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() { csvReader.TrimLeadingSpace = true records, err := csvReader.ReadAll() if err != nil { - logger.Errorf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) + logger.ErrMsgf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) return } updated := make(map[string]bool)