oauth2-proxy/pkg/logger/logger_test.go

595 lines
16 KiB
Go

package logger
import (
"bytes"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"text/template"
"time"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
)
// resetLogger puts every package-level logger setting back to its default
// value so that each test starts from the same, isolated configuration.
// Writers are not touched here; each test installs its own via Setup.
func resetLogger(t *testing.T) {
	t.Helper()

	// Level, output format and caller/time flags.
	logLevel.Set(slog.LevelInfo)
	logFormat = "text"
	flags = LstdFlags
	localTime = true

	// Enable switches for each log category, plus error routing.
	standardEnabled = true
	authEnabled = true
	reqEnabled = true
	errToInfo = false

	// Default line templates for standard, auth and request logs.
	stdLogTemplate = template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat))
	authTemplate = template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat))
	reqTemplate = template.Must(template.New("req-log").Parse(DefaultRequestLoggingFormat))

	// Request-path exclusions and client address extraction.
	excludePaths = nil
	getClientFunc = func(r *http.Request) string { return r.RemoteAddr }
}
// parseJSON decodes the first line of captured log output as a JSON object.
// Any following lines are ignored. The test is failed immediately when that
// first line is blank or is not valid JSON.
func parseJSON(t *testing.T, data []byte) map[string]any {
	t.Helper()

	// Only the first record is of interest when several were written.
	firstLine, _, _ := strings.Cut(string(data), "\n")
	firstLine = strings.TrimSpace(firstLine)
	if firstLine == "" {
		t.Fatal("empty log output")
	}

	var decoded map[string]any
	if err := json.Unmarshal([]byte(firstLine), &decoded); err != nil {
		t.Fatalf("failed to parse JSON log: %v\nraw: %s", err, firstLine)
	}
	return decoded
}
// TestSetup_JSONFormat verifies that the JSON format emits the message, the
// supplied attributes, the level, and the time and source fields.
func TestSetup_JSONFormat(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &stdout, &stderr)

	Info("hello", "key", "val")
	entry := parseJSON(t, stdout.Bytes())

	// Fields whose exact value is known.
	exact := []struct {
		field  string
		want   string
		format string
	}{
		{"msg", "hello", "expected msg=hello, got %v"},
		{"key", "val", "expected key=val, got %v"},
		{"level", "INFO", "expected level=INFO, got %v"},
	}
	for _, c := range exact {
		if got := entry[c.field]; got != c.want {
			t.Errorf(c.format, got)
		}
	}

	// Fields whose value varies per run; only presence is checked.
	if _, present := entry["time"]; !present {
		t.Error("expected time field in JSON output")
	}
	if _, present := entry["source"]; !present {
		t.Error("expected source field in JSON output")
	}
}
// TestSetup_TextFormat verifies that the "text" format renders through the
// standard template (message followed by key=val attributes) instead of
// slog's key=value line, and that each line ends with a newline.
func TestSetup_TextFormat(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "text", &stdout, &stderr)

	Info("hello", "key", "val")
	text := stdout.String()

	looksLikeSlogText := strings.Contains(text, "level=INFO")
	hasMessageAndAttrs := strings.Contains(text, "hello key=val")
	endsWithNewline := strings.HasSuffix(text, "\n")

	if looksLikeSlogText {
		t.Errorf("did not expect slog key=value text output, got: %s", text)
	}
	if !hasMessageAndAttrs {
		t.Errorf("expected message and attrs in template output, got: %s", text)
	}
	if !endsWithNewline {
		t.Errorf("expected text output to end with newline, got: %s", text)
	}
}
// TestTextFormat_StandardTemplate verifies that a custom standard template
// replaces the default one. With "{{.Message}}", the expected output shows
// the attributes appended to the message.
func TestTextFormat_StandardTemplate(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "text", &stdout, &stderr)
	SetStandardTemplate("{{.Message}}")

	Info("hello", "key", "val")

	const want = "hello key=val\n"
	if got := stdout.String(); got != want {
		t.Errorf("expected custom standard template output %q, got %q", want, got)
	}
}
// TestTextFormat_StandardTemplateFileUsesCaller verifies that {{.File}}
// reports the code that called the logging function (this test file), not
// the logger's own wrapper file.
func TestTextFormat_StandardTemplateFileUsesCaller(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "text", &stdout, &stderr)
	SetStandardTemplate("{{.File}}|{{.Message}}")

	Info("hello")
	rendered := stdout.String()

	if !strings.Contains(rendered, "logger_test.go:") {
		t.Errorf("expected caller file in text output, got: %s", rendered)
	}
	if strings.Contains(rendered, "logger.go:") {
		t.Errorf("expected external caller file, got logger wrapper file: %s", rendered)
	}
}
// TestSetFlags_ControlsStandardTemplateFile verifies that clearing all flags
// (and hence the short-file flag) makes {{.File}} render the full caller path.
func TestSetFlags_ControlsStandardTemplateFile(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "text", &stdout, &stderr)
	SetStandardTemplate("{{.File}}")
	SetFlags(0)

	Info("hello")

	if rendered := stdout.String(); !strings.Contains(rendered, "/pkg/logger/logger_test.go:") {
		t.Errorf("expected full caller file when Lshortfile is disabled, got: %s", rendered)
	}
}
// TestTextFormat_StandardLoggingDisabled verifies that disabling standard
// logging suppresses Info and Warn output on both writers.
func TestTextFormat_StandardLoggingDisabled(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "text", &stdout, &stderr)
	SetStandardEnabled(false)

	Info("hidden")
	Warn("also hidden")

	if written := stdout.Len() + stderr.Len(); written != 0 {
		t.Errorf("expected standard logs to be disabled, stdout=%q stderr=%q", stdout.String(), stderr.String())
	}
}
// TestLevelFiltering verifies that, at Info level, Debug is dropped, Info is
// written to stdout and errors are written to the error writer.
func TestLevelFiltering(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &stdout, &stderr)

	// Below the configured level: nothing written.
	Debug("should not appear")
	if stdout.Len() != 0 {
		t.Error("Debug message should be filtered at Info level")
	}

	// At the configured level: written to stdout.
	Info("should appear")
	if stdout.Len() == 0 {
		t.Error("Info message should appear at Info level")
	}
	if msg := parseJSON(t, stdout.Bytes())["msg"]; msg != "should appear" {
		t.Errorf("expected msg='should appear', got %v", msg)
	}

	// Errors land on the error writer.
	ErrMsg("error msg")
	if stderr.Len() == 0 {
		t.Error("Error message should appear in errBuf")
	}
}
// TestSetLevel_Runtime verifies that SetLevel changes filtering immediately,
// without calling Setup again.
func TestSetLevel_Runtime(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &stdout, &stderr)

	Debug("hidden")
	if stdout.Len() != 0 {
		t.Error("Debug should be hidden at Info level")
	}

	SetLevel(slog.LevelDebug)
	Debug("visible")
	if stdout.Len() == 0 {
		t.Error("Debug should be visible after SetLevel(Debug)")
	}
	if msg := parseJSON(t, stdout.Bytes())["msg"]; msg != "visible" {
		t.Errorf("expected msg='visible', got %v", msg)
	}
}
// TestLevelSplitHandler verifies the writer split: Info goes to stdout only,
// while Warn and Error go to stderr.
func TestLevelSplitHandler(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)

	// INFO: stdout only.
	Info("info msg")
	if stdout.Len() == 0 {
		t.Error("Info should go to stdout")
	}
	if stderr.Len() > 0 {
		t.Error("Info should NOT go to stderr")
	}
	stdout.Reset()

	// WARN: stderr only.
	Warn("warn msg")
	if stderr.Len() == 0 {
		t.Error("Warn should go to stderr")
	}
	if stdout.Len() > 0 {
		t.Error("Warn should NOT go to stdout")
	}
	stderr.Reset()

	// ERROR: stderr.
	ErrMsg("error msg")
	if stderr.Len() == 0 {
		t.Error("Error should go to stderr")
	}
}
// TestErrToInfo verifies that with ErrToInfo enabled, error-level messages
// are written to stdout rather than stderr.
func TestErrToInfo(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &stdout, &stderr)
	SetErrToInfo(true)

	ErrMsg("goes to stdout")

	if stdout.Len() == 0 {
		t.Error("Error should go to stdout when ErrToInfo is true")
	}
	if stderr.Len() > 0 {
		t.Error("Error should NOT go to stderr when ErrToInfo is true")
	}
}
// TestLogAuth_Success verifies that an AuthSuccess event is logged at INFO to
// stdout with the user, status, request ID and message fields populated.
func TestLogAuth_Success(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.RemoteAddr = "192.168.1.1:1234"
	req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{RequestID: "test-request-id"})

	LogAuth("user@test.com", req, AuthSuccess, "authenticated via OAuth2")

	entry := parseJSON(t, stdout.Bytes())
	for _, c := range []struct{ field, want, format string }{
		{"level", "INFO", "AuthSuccess should log at INFO, got %v"},
		{"user", "user@test.com", "expected user=user@test.com, got %v"},
		{"status", "AuthSuccess", "expected status=AuthSuccess, got %v"},
		{"request_id", "test-request-id", "expected request_id=test-request-id, got %v"},
		{"msg", "authenticated via OAuth2", "expected msg='authenticated via OAuth2', got %v"},
	} {
		if got := entry[c.field]; got != c.want {
			t.Errorf(c.format, got)
		}
	}
}
// TestLogAuth_Failure verifies that an AuthFailure event is logged at WARN
// and therefore written to the error writer.
func TestLogAuth_Failure(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)

	req := middlewareapi.AddRequestScope(
		httptest.NewRequest(http.MethodGet, "/test", nil),
		&middlewareapi.RequestScope{RequestID: "req-id"},
	)
	LogAuth("bad-user", req, AuthFailure, "invalid credentials")

	// WARN is routed to the error writer by the level split.
	entry := parseJSON(t, stderr.Bytes())
	if lvl := entry["level"]; lvl != "WARN" {
		t.Errorf("AuthFailure should log at WARN, got %v", lvl)
	}
	if status := entry["status"]; status != "AuthFailure" {
		t.Errorf("expected status=AuthFailure, got %v", status)
	}
}
// TestLogAuth_Error verifies that an AuthError event is written to the error
// writer at ERROR level.
func TestLogAuth_Error(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)

	req := middlewareapi.AddRequestScope(
		httptest.NewRequest(http.MethodGet, "/test", nil),
		&middlewareapi.RequestScope{RequestID: "req-id"},
	)
	LogAuth("user", req, AuthError, "internal error")

	if lvl := parseJSON(t, stderr.Bytes())["level"]; lvl != "ERROR" {
		t.Errorf("AuthError should log at ERROR, got %v", lvl)
	}
}
// TestLogAuth_Disabled verifies that LogAuth writes nothing to either writer
// once auth logging is switched off.
func TestLogAuth_Disabled(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)
	SetAuthEnabled(false)

	req := middlewareapi.AddRequestScope(
		httptest.NewRequest(http.MethodGet, "/test", nil),
		&middlewareapi.RequestScope{RequestID: "req-id"},
	)
	LogAuth("user", req, AuthSuccess, "should not appear")

	if stdout.Len()+stderr.Len() != 0 {
		t.Error("LogAuth should produce no output when auth logging is disabled")
	}
}
// TestLogAuth_TextTemplate verifies that a custom auth template renders the
// client address, request ID, user, status and message (with extra attributes
// appended). It also checks that the line goes to stdout, not stderr.
func TestLogAuth_TextTemplate(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "text", &stdout, &stderr)
	SetAuthTemplate("{{.Client}}|{{.RequestID}}|{{.Username}}|{{.Status}}|{{.Message}}")

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.RemoteAddr = "192.168.1.1:1234"
	req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{RequestID: "req-id"})

	LogAuth("user@test.com", req, AuthSuccess, "authenticated", "provider", "oidc")

	const want = "192.168.1.1:1234|req-id|user@test.com|AuthSuccess|authenticated provider=oidc\n"
	if got := stdout.String(); got != want {
		t.Errorf("expected custom auth template output %q, got %q", want, got)
	}
	if stderr.Len() > 0 {
		t.Errorf("expected text auth log to use standard writer, got stderr %q", stderr.String())
	}
}
// TestLogRequest verifies the structured fields that LogRequest writes to
// stdout in JSON mode.
func TestLogRequest(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)

	req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil)
	req.RemoteAddr = "127.0.0.1:5678"
	req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{RequestID: "req-123"})

	LogRequest("testuser", "backend", req, *req.URL, time.Now(), 200, 1024)

	entry := parseJSON(t, stdout.Bytes())

	// String-valued fields.
	for _, c := range []struct{ field, want, format string }{
		{"msg", "request", "expected msg=request, got %v"},
		{"level", "INFO", "expected level=INFO, got %v"},
		{"user", "testuser", "expected user=testuser, got %v"},
		{"upstream", "backend", "expected upstream=backend, got %v"},
		{"method", "GET", "expected method=GET, got %v"},
	} {
		if got := entry[c.field]; got != c.want {
			t.Errorf(c.format, got)
		}
	}

	// Numeric fields: encoding/json decodes numbers into float64.
	for _, c := range []struct {
		field  string
		want   int
		format string
	}{
		{"status_code", 200, "expected status_code=200, got %v"},
		{"response_size", 1024, "expected response_size=1024, got %v"},
	} {
		n, isNumber := entry[c.field].(float64)
		if !isNumber || int(n) != c.want {
			t.Errorf(c.format, entry[c.field])
		}
	}

	// duration_s depends on wall-clock time, so only its presence is checked.
	if _, present := entry["duration_s"]; !present {
		t.Error("expected duration_s field")
	}
	if id := entry["request_id"]; id != "req-123" {
		t.Errorf("expected request_id=req-123, got %v", id)
	}
}
// TestLogRequest_ExcludePaths verifies that requests to every configured
// exclude path produce no output on either writer, while a request to a
// non-excluded path is still logged. The positive control keeps the test from
// passing vacuously if LogRequest stopped logging altogether.
func TestLogRequest_ExcludePaths(t *testing.T) {
	resetLogger(t)
	buf := &bytes.Buffer{}
	errBuf := &bytes.Buffer{}
	Setup(slog.LevelDebug, "json", buf, errBuf)
	excluded := []string{"/healthz", "/ping"}
	SetExcludePaths(excluded)

	// logPath issues a GET for path through LogRequest with a request scope attached.
	logPath := func(path string) {
		t.Helper()
		req := httptest.NewRequest("GET", path, nil)
		scope := &middlewareapi.RequestScope{RequestID: "req-id"}
		req = middlewareapi.AddRequestScope(req, scope)
		reqURL := *req.URL
		LogRequest("user", "-", req, reqURL, time.Now(), 200, 0)
	}

	// Every excluded path must be silent, not just the first one.
	for _, path := range excluded {
		logPath(path)
		if buf.Len() > 0 || errBuf.Len() > 0 {
			t.Errorf("LogRequest should not log excluded paths: %q produced stdout=%q stderr=%q",
				path, buf.String(), errBuf.String())
		}
		buf.Reset()
		errBuf.Reset()
	}

	// Positive control: a path outside the exclude list is still logged.
	logPath("/foo")
	if buf.Len() == 0 {
		t.Error("LogRequest should still log paths that are not excluded")
	}
}
// TestLogRequest_Disabled verifies that LogRequest writes nothing to stdout
// once request logging is switched off.
func TestLogRequest_Disabled(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)
	SetReqEnabled(false)

	req := middlewareapi.AddRequestScope(
		httptest.NewRequest(http.MethodGet, "/foo", nil),
		&middlewareapi.RequestScope{RequestID: "req-id"},
	)
	LogRequest("user", "-", req, *req.URL, time.Now(), 200, 0)

	if stdout.Len() != 0 {
		t.Error("LogRequest should produce no output when request logging is disabled")
	}
}
// TestLogRequest_TextTemplate verifies that a custom request template renders
// user, upstream, method, URI, status, size and request ID. The expected line
// shows the URI wrapped in double quotes.
func TestLogRequest_TextTemplate(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "text", &stdout, &stderr)
	SetReqTemplate("{{.Username}}|{{.Upstream}}|{{.RequestMethod}}|{{.RequestURI}}|{{.StatusCode}}|{{.ResponseSize}}|{{.RequestID}}")

	req := httptest.NewRequest(http.MethodGet, "/foo/bar?x=1", nil)
	req.RemoteAddr = "127.0.0.1:5678"
	req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{RequestID: "req-123"})

	LogRequest("testuser", "backend", req, *req.URL, time.Now(), 204, 42)

	const want = `testuser|backend|GET|"/foo/bar?x=1"|204|42|req-123` + "\n"
	if got := stdout.String(); got != want {
		t.Errorf("expected custom request template output %q, got %q", want, got)
	}
}
// TestFatalMsg verifies that FatalMsg writes an ERROR-level record to the
// error writer and then calls exitFunc with status 1.
//
// exitFunc is replaced with a recorder so the test process keeps running. The
// previous value is restored with t.Cleanup so that later tests (and any
// later FatalMsg call) see the original exit behaviour rather than a nil or
// no-op function.
func TestFatalMsg(t *testing.T) {
	resetLogger(t)
	buf := &bytes.Buffer{}
	errBuf := &bytes.Buffer{}
	Setup(slog.LevelInfo, "json", buf, errBuf)

	origExit := exitFunc
	t.Cleanup(func() { exitFunc = origExit })

	// -1 is a sentinel meaning exitFunc was never invoked, so it cannot be
	// confused with an exit status of 0.
	exitCode := -1
	exitFunc = func(code int) { exitCode = code }

	FatalMsg("fatal error")

	if exitCode != 1 {
		t.Errorf("expected exit code 1, got %d", exitCode)
	}
	if errBuf.Len() == 0 {
		t.Error("Fatal should produce error-level output")
	}
	m := parseJSON(t, errBuf.Bytes())
	if m["level"] != "ERROR" {
		t.Errorf("expected level=ERROR, got %v", m["level"])
	}
}
// TestInfof verifies that Infof formats its arguments into the message and
// logs at INFO.
func TestInfof(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &stdout, &stderr)

	Infof("hello %s", "world")

	entry := parseJSON(t, stdout.Bytes())
	if msg := entry["msg"]; msg != "hello world" {
		t.Errorf("expected msg='hello world', got %v", msg)
	}
	if lvl := entry["level"]; lvl != "INFO" {
		t.Errorf("expected level=INFO, got %v", lvl)
	}
}
// TestErrMsgf verifies that ErrMsgf formats its arguments into the message and
// writes an ERROR record to the error writer.
func TestErrMsgf(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &stdout, &stderr)

	ErrMsgf("error: %s", "something")

	entry := parseJSON(t, stderr.Bytes())
	if msg := entry["msg"]; msg != "error: something" {
		t.Errorf("expected msg='error: something', got %v", msg)
	}
	if lvl := entry["level"]; lvl != "ERROR" {
		t.Errorf("expected level=ERROR, got %v", lvl)
	}
}
// TestLogAuth verifies that a successful auth event written to stdout carries
// the supplied message and user.
func TestLogAuth(t *testing.T) {
	resetLogger(t)
	var stdout, stderr bytes.Buffer
	Setup(slog.LevelDebug, "json", &stdout, &stderr)

	req := middlewareapi.AddRequestScope(
		httptest.NewRequest(http.MethodGet, "/test", nil),
		&middlewareapi.RequestScope{RequestID: "req-id"},
	)
	LogAuth("user@test.com", req, AuthSuccess, "authenticated via OAuth2")

	entry := parseJSON(t, stdout.Bytes())
	if msg := entry["msg"]; msg != "authenticated via OAuth2" {
		t.Errorf("expected msg='authenticated via OAuth2', got %v", msg)
	}
	if user := entry["user"]; user != "user@test.com" {
		t.Errorf("expected user=user@test.com, got %v", user)
	}
}
// TestSetOutput verifies that SetOutput redirects subsequent standard output
// to the new writer and stops writing to the previous one.
func TestSetOutput(t *testing.T) {
	resetLogger(t)
	var first, stderr bytes.Buffer
	Setup(slog.LevelInfo, "json", &first, &stderr)

	Info("before")
	if first.Len() == 0 {
		t.Error("expected output in buf1")
	}

	var second bytes.Buffer
	SetOutput(&second)
	first.Reset()

	Info("after")
	if second.Len() == 0 {
		t.Error("expected output in buf2 after SetOutput")
	}
	if first.Len() > 0 {
		t.Error("buf1 should have no new output after SetOutput")
	}
}
// TestGetLevel verifies that GetLevel reports the level passed to Setup and
// reflects a later SetLevel call.
func TestGetLevel(t *testing.T) {
	resetLogger(t)
	var sink bytes.Buffer
	Setup(slog.LevelWarn, "json", &sink, &sink)

	if got := GetLevel(); got != slog.LevelWarn {
		t.Errorf("expected LevelWarn, got %v", got)
	}

	SetLevel(slog.LevelDebug)
	if got := GetLevel(); got != slog.LevelDebug {
		t.Errorf("expected LevelDebug, got %v", got)
	}
}