package logging_test

import (
	"encoding"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"testing"

	"go.sudomsg.com/kit/logging"
	"go.sudomsg.com/kit/logging/test"
)

// findAttr returns the first top-level attribute of record whose key matches,
// and reports whether one was found.
func findAttr(record slog.Record, key string) (slog.Attr, bool) {
	var found slog.Attr
	ok := false
	record.Attrs(func(a slog.Attr) bool {
		if a.Key == key {
			found, ok = a, true
			return false // stop iterating once the key is found
		}
		return true
	})
	return found, ok
}

// assertStringAttrs reports a test error for every key in want that is
// missing from record's top-level attributes or whose string form differs
// from the expected value.
func assertStringAttrs(t *testing.T, record slog.Record, want map[string]string) {
	t.Helper()
	for key, value := range want {
		attr, ok := findAttr(record, key)
		switch {
		case !ok:
			t.Errorf("record %q: missing attribute %q", record.Message, key)
		case attr.Value.String() != value:
			t.Errorf("record %q: attribute %q = %q, want %q",
				record.Message, key, attr.Value.String(), value)
		}
	}
}

func TestWithLogger_And_FromContext(t *testing.T) {
	ctx := t.Context()
	mock := test.NewMockLogHandler(t)
	logger := slog.New(mock)

	ctx = logging.WithLogger(ctx, logger)

	if got := logging.FromContext(ctx); got != logger {
		t.Errorf("expected logger from context to match original logger")
	}
}

func TestFromContext_DefaultFallback(t *testing.T) {
	ctx := t.Context()

	if got := logging.FromContext(ctx); got != slog.Default() {
		t.Errorf("expected default logger when no logger is in context")
	}
}

func TestRecoverAndLog(t *testing.T) {
	ctx := t.Context()
	mock := test.NewMockLogHandler(t)
	ctx = logging.WithLogger(ctx, slog.New(mock))

	err := errors.New("something broke")
	logging.RecoverAndLog(ctx, "panic recovered", err)

	records := mock.Records()
	if len(records) != 1 {
		t.Fatalf("expected 1 log record, got %d", len(records))
	}

	record := records[0]
	if record.Message != "panic recovered" {
		t.Errorf("expected message %q, got %q", "panic recovered", record.Message)
	}
	if _, ok := findAttr(record, "stack"); !ok {
		t.Errorf("expected 'stack' attribute in log")
	}
}

func TestWith(t *testing.T) {
	ctx := t.Context()
	mock := test.NewMockLogHandler(t)
	ctx = logging.WithLogger(ctx, slog.New(mock))

	user := "user"
	id := "1234"
	want := map[string]string{"user": user, "id": id}

	logger2, newCtx := logging.With(ctx, "user", user, "id", id)
	logger2.InfoContext(ctx, "test message")

	records := mock.Records()
	if len(records) != 1 {
		t.Fatalf("expected 1 record, got %d", len(records))
	}
	assertStringAttrs(t, records[0], want)

	// The context returned by With must carry a logger with the same
	// attributes, not merely some logger.
	logging.FromContext(newCtx).InfoContext(ctx, "second message")

	records = mock.Records()
	if len(records) != 2 {
		t.Fatalf("expected 2 log records, got %d", len(records))
	}
	assertStringAttrs(t, records[1], want)
}

func TestWithGroup(t *testing.T) {
	ctx := t.Context()
	mock := test.NewMockLogHandler(t)
	ctx = logging.WithLogger(ctx, slog.New(mock))

	logger2, _ := logging.WithGroup(ctx, "foo")
	logger2.InfoContext(ctx, "test message", "key", "value")

	records := mock.Records()
	if len(records) != 1 {
		t.Fatalf("expected 1 record, got %d", len(records))
	}

	group, ok := findAttr(records[0], "foo")
	if !ok || group.Value.Kind() != slog.KindGroup {
		t.Fatalf("expected group attribute %q in log record, %v", "foo", records)
	}

	// The attribute logged through the grouped logger must be nested
	// inside the "foo" group.
	for _, inner := range group.Value.Group() {
		if inner.Key == "key" && inner.Value.String() == "value" {
			return
		}
	}
	t.Errorf("expected attribute key=value inside group %q, got %v", "foo", group)
}

// textCodec is satisfied by PT = *T when *T can encode T to text, decode it
// back, and name it (used for subtest names).
type textCodec[T any] interface {
	*T
	encoding.TextMarshaler
	encoding.TextUnmarshaler
	fmt.Stringer
}

// testTextRoundTrip checks that each value in values survives MarshalText
// followed by UnmarshalText, both verbatim ("case same") and with the encoded
// text upper-cased ("case not same", i.e. decoding is case-insensitive), and
// that decoding the unknown name "invalid" fails.
func testTextRoundTrip[T comparable, PT textCodec[T]](t *testing.T, values []T) {
	t.Helper()

	t.Run("RoundTrip", func(t *testing.T) {
		t.Parallel()
		for _, original := range values {
			t.Run(PT(&original).String(), func(t *testing.T) {
				text, err := PT(&original).MarshalText()
				if err != nil {
					t.Fatalf("MarshalText failed: %v", err)
				}

				inputs := []struct {
					name string
					text []byte
				}{
					{"case same", text},
					{"case not same", []byte(strings.ToUpper(string(text)))},
				}
				for _, in := range inputs {
					t.Run(in.name, func(t *testing.T) {
						var decoded T
						if err := PT(&decoded).UnmarshalText(in.text); err != nil {
							t.Fatalf("UnmarshalText(%q) failed: %v", in.text, err)
						}
						if decoded != original {
							t.Fatalf("Round-trip mismatch: got %v, want %v", decoded, original)
						}
					})
				}
			})
		}
	})

	t.Run("Invalid", func(t *testing.T) {
		t.Parallel()
		input := "invalid"
		var decoded T
		if err := PT(&decoded).UnmarshalText([]byte(input)); err == nil {
			t.Errorf("expected error for input %q, got none", input)
		}
	})
}

func TestLogSink(t *testing.T) {
	t.Parallel()
	testTextRoundTrip(t, []logging.LogSink{
		logging.SinkStdout,
		logging.SinkStderr,
		logging.SinkFile,
	})
}

func TestLogFormat(t *testing.T) {
	t.Parallel()
	testTextRoundTrip(t, []logging.LogFormat{
		logging.FormatText,
		logging.FormatJSON,
	})
}