summaryrefslogtreecommitdiffstats
path: root/logging/test/handler.go
blob: d1f5d40b808d63a6fc90226cb502585e2147558f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Package test provides mocks and helpers for testing code that uses slog logging.
//
// It includes MockHandler, a thread-safe slog.Handler implementation for capturing log records in tests.
package test

import (
	"context"
	"log/slog"
	"slices"
	"sync"
	"testing"
)

// logRecorder is a mutex-guarded, append-only store of slog records.
type logRecorder struct {
	// mu guards records.
	mu sync.Mutex
	// records holds the captured records in the order they were appended.
	records []slog.Record
}

// Append stores a clone of r at the end of the recorder.
func (rec *logRecorder) Append(r slog.Record) {
	// Clone before taking the lock so the critical section is only the append.
	snapshot := r.Clone()

	rec.mu.Lock()
	defer rec.mu.Unlock()
	rec.records = append(rec.records, snapshot)
}

// Records returns a copy of the slice of records captured so far, so callers
// can modify the returned slice without affecting the recorder.
func (rec *logRecorder) Records() []slog.Record {
	rec.mu.Lock()
	snapshot := slices.Clone(rec.records)
	rec.mu.Unlock()

	return snapshot
}

// MockHandler is a slog.Handler that records log records for testing.
//
// It supports attribute and group chaining like slog's built-in handlers.
// Use NewMockLogHandler to construct one, then pass it to slog.New in your tests.
//
// All recorded logs can be retrieved with the Records method.
type MockHandler struct {
	// recorder is the store that handled records end up in. Handlers derived
	// through WithAttrs and WithGroup share their parent's recorder.
	recorder *logRecorder
	// parent is the handler this one was derived from. It is nil for the
	// root handler, which is the one that writes to recorder.
	parent   *MockHandler
	// group is the group name set by WithGroup; empty if this handler does
	// not open a group.
	group    string
	// attrs are the attributes bound by WithAttrs.
	attrs    []slog.Attr
}

// Compile-time check that *MockHandler implements slog.Handler.
var _ slog.Handler = (*MockHandler)(nil)

// NewMockLogHandler returns a root MockHandler backed by a fresh, empty
// recorder. tb is only used to mark this function as a test helper.
func NewMockLogHandler(tb testing.TB) *MockHandler {
	tb.Helper()

	root := new(MockHandler)
	root.recorder = new(logRecorder)
	return root
}

// Enabled implements slog.Handler. It reports true for every level, so no
// record is filtered out before reaching Handle.
func (h *MockHandler) Enabled(ctx context.Context, level slog.Level) bool {
	const captureAll = true
	return captureAll
}

// Handle implements slog.Handler.
//
// On the root handler (no parent) the record is appended to the recorder.
// On a derived handler, a new record is built carrying this handler's bound
// attributes followed by the record's own attributes. If this handler opens a
// group, those attributes are nested under a single attribute keyed by the
// group name. The new record is then passed to the parent, so each level of
// WithAttrs/WithGroup chaining is applied from the innermost handler outwards.
//
// It returns whatever the parent chain returns, which is nil at the root.
func (h *MockHandler) Handle(ctx context.Context, r slog.Record) error {
	if h.parent == nil {
		// Append stores its own clone of the record, so no copy is made here.
		h.recorder.Append(r)
		return nil
	}

	// Work on a copy of the bound attributes so appending the record's
	// attributes never writes into the slice held by this handler.
	attrs := slices.Clone(h.attrs)
	r.Attrs(func(attr slog.Attr) bool {
		attrs = append(attrs, attr)
		return true
	})

	newRecord := slog.NewRecord(r.Time, r.Level, r.Message, r.PC)
	if h.group != "" {
		// Everything seen at this level is qualified by the group name.
		newRecord.AddAttrs(slog.Attr{
			Key:   h.group,
			Value: slog.GroupValue(attrs...),
		})
	} else {
		newRecord.AddAttrs(attrs...)
	}

	return h.parent.Handle(ctx, newRecord)
}

// WithAttrs implements slog.Handler. It returns a child handler that adds
// attrs, ahead of the record's own attributes, to every record it handles.
// The child shares this handler's recorder.
//
// If attrs is empty the receiver is returned unchanged, in the same way
// WithGroup treats an empty name. This avoids a wrapper that would only
// rebuild each record without adding anything.
func (h *MockHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return h
	}
	return &MockHandler{
		recorder: h.recorder,
		parent:   h,
		attrs:    attrs,
	}
}

// WithGroup implements slog.Handler. It returns a child handler that nests
// the attributes of every record it handles under name. The child shares this
// handler's recorder. An empty name returns the receiver unchanged.
func (h *MockHandler) WithGroup(name string) slog.Handler {
	if name == "" {
		return h
	}

	child := new(MockHandler)
	child.recorder = h.recorder
	child.parent = h
	child.group = name
	return child
}

// Records returns a copy of the slog.Records captured so far in the recorder
// backing this handler.
//
// It is safe to call from multiple goroutines.
func (h *MockHandler) Records() []slog.Record {
	captured := h.recorder.Records()
	return captured
}