diff options
author | Marc Pervaz Boocha <marcpervaz@qburst.com> | 2025-02-25 09:47:02 +0530 |
---|---|---|
committer | Marc Pervaz Boocha <marcpervaz@qburst.com> | 2025-02-27 13:38:14 +0530 |
commit | a5194b3be4768d42b27ef9e2d0d52479e0436758 (patch) | |
tree | 4e60b1b790e63b8aaf8dba77acbfe4599dedbee0 | |
parent | Add additional test cases to improve coverage and robustness (diff) | |
download | cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.tar cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.tar.gz cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.tar.bz2 cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.tar.lz cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.tar.xz cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.tar.zst cache-a5194b3be4768d42b27ef9e2d0d52479e0436758.zip |
Removed testify and improved tests
-rw-r--r-- | conn.go | 72 | ||||
-rw-r--r-- | conn_test.go | 128 | ||||
-rw-r--r-- | encoding.go | 22 | ||||
-rw-r--r-- | encoding_test.go | 293 | ||||
-rw-r--r-- | evict.go | 64 | ||||
-rw-r--r-- | evict_test.go | 200 | ||||
-rw-r--r-- | examples/basic_usage/main.go | 2 | ||||
-rw-r--r-- | examples/eviction_policy/main.go | 4 | ||||
-rw-r--r-- | go.mod | 10 | ||||
-rw-r--r-- | go.sum | 18 | ||||
-rw-r--r-- | internal/pausedtimer/timer.go | 2 | ||||
-rw-r--r-- | internal/pausedtimer/timer_test.go | 70 | ||||
-rw-r--r-- | store.go | 66 | ||||
-rw-r--r-- | store_test.go | 101 | ||||
-rw-r--r-- | utils.go | 2 |
15 files changed, 672 insertions, 382 deletions
@@ -11,7 +11,7 @@ import ( "github.com/vmihailenco/msgpack/v5" ) - // db represents a cache database with file-backed storage and in-memory operation. +// db represents a cache database with file-backed storage and in-memory operation. type db struct { File io.WriteSeeker Store store @@ -19,23 +19,26 @@ type db struct { wg sync.WaitGroup } - // Option is a function type for configuring the db. +// Option is a function type for configuring the db. type Option func(*db) error - // openFile opens a file-backed cache database with the given options. +// openFile opens a file-backed cache database with the given options. func openFile(filename string, options ...Option) (*db, error) { ret, err := openMem(options...) if err != nil { return nil, err } - file, err := lockedfile.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0666) + + file, err := lockedfile.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0o666) if err != nil { return nil, err } + fileInfo, err := file.Stat() if err != nil { return nil, err } + if fileInfo.Size() == 0 { ret.File = file ret.Flush() @@ -44,28 +47,31 @@ func openFile(filename string, options ...Option) (*db, error) { if err != nil { return nil, err } + ret.File = file } return ret, nil } - // openMem initializes an in-memory cache database with the given options. +// openMem initializes an in-memory cache database with the given options. func openMem(options ...Option) (*db, error) { ret := &db{} ret.Store.Init() ret.SetConfig(options...) + return ret, nil } - // Start begins the background worker for periodic tasks. +// Start begins the background worker for periodic tasks. func (d *db) Start() { d.Stop = make(chan struct{}) d.wg.Add(1) + go d.backgroundWorker() } - // SetConfig applies configuration options to the db. +// SetConfig applies configuration options to the db. func (d *db) SetConfig(options ...Option) error { d.Store.mu.Lock() defer d.Store.mu.Unlock() @@ -75,41 +81,45 @@ func (d *db) SetConfig(options ...Option) error { return err } } + return nil } - // WithPolicy sets the eviction policy for the cache. +// WithPolicy sets the eviction policy for the cache. func WithPolicy(e EvictionPolicyType) Option { return func(d *db) error { return d.Store.Policy.SetPolicy(e) } } - // WithMaxCost sets the maximum cost for the cache. +// WithMaxCost sets the maximum cost for the cache. func WithMaxCost(maxCost uint64) Option { return func(d *db) error { d.Store.MaxCost = maxCost + return nil } } - // SetSnapshotTime sets the interval for taking snapshots of the cache. +// SetSnapshotTime sets the interval for taking snapshots of the cache. func SetSnapshotTime(t time.Duration) Option { return func(d *db) error { d.Store.SnapshotTicker.Reset(t) + return nil } } - // SetCleanupTime sets the interval for cleaning up expired entries. +// SetCleanupTime sets the interval for cleaning up expired entries. func SetCleanupTime(t time.Duration) Option { return func(d *db) error { d.Store.CleanupTicker.Reset(t) + return nil } } - // backgroundWorker performs periodic tasks such as snapshotting and cleanup. +// backgroundWorker performs periodic tasks such as snapshotting and cleanup. func (d *db) backgroundWorker() { defer d.wg.Done() @@ -132,12 +142,13 @@ func (d *db) backgroundWorker() { } } - // Close stops the background worker and cleans up resources. +// Close stops the background worker and cleans up resources. func (d *db) Close() { close(d.Stop) d.wg.Wait() d.Flush() d.Clear() + if d.File != nil { closer, ok := d.File.(io.Closer) if ok { @@ -146,15 +157,16 @@ func (d *db) Close() { } } - // Flush writes the current state of the store to the file. +// Flush writes the current state of the store to the file. func (d *db) Flush() error { if d.File != nil { return d.Store.Snapshot(d.File) } + return nil } - // Clear removes all entries from the in-memory store. +// Clear removes all entries from the in-memory store. func (d *db) Clear() { d.Store.Clear() } @@ -162,42 +174,46 @@ func (d *db) Clear() { var ErrKeyNotFound = errors.New("key not found") // ErrKeyNotFound is returned when a key is not found in the cache. // The Cache database. Can be initialized by either OpenFile or OpenMem. Uses per DB Locks. - // DB represents a generic cache database with key-value pairs. +// DB represents a generic cache database with key-value pairs. type DB[K any, V any] struct { *db } - // OpenFile opens a file-backed cache database with the specified options. +// OpenFile opens a file-backed cache database with the specified options. func OpenFile[K any, V any](filename string, options ...Option) (DB[K, V], error) { ret, err := openFile(filename, options...) if err != nil { return zero[DB[K, V]](), err } + ret.Start() + return DB[K, V]{db: ret}, nil } - // OpenMem initializes an in-memory cache database with the specified options. +// OpenMem initializes an in-memory cache database with the specified options. func OpenMem[K any, V any](options ...Option) (DB[K, V], error) { ret, err := openMem(options...) if err != nil { return zero[DB[K, V]](), err } + ret.Start() + return DB[K, V]{db: ret}, nil } - // marshal serializes a value using msgpack. +// marshal serializes a value using msgpack. func marshal[T any](v T) ([]byte, error) { return msgpack.Marshal(v) } - // unmarshal deserializes data into a value using msgpack. +// unmarshal deserializes data into a value using msgpack. func unmarshal[T any](data []byte, v *T) error { return msgpack.Unmarshal(data, v) } - // Get retrieves a value from the cache by key and returns its TTL. +// Get retrieves a value from the cache by key and returns its TTL. func (h *DB[K, V]) Get(key K, value *V) (time.Duration, error) { keyData, err := marshal(key) if err != nil { @@ -208,44 +224,52 @@ func (h *DB[K, V]) Get(key K, value *V) (time.Duration, error) { if !ok { return 0, ErrKeyNotFound } + if v != nil { if err = unmarshal(v, value); err != nil { return 0, err } } + return ttl, err } - // GetValue retrieves a value from the cache by key and returns the value and its TTL. +// GetValue retrieves a value from the cache by key and returns the value and its TTL. func (h *DB[K, V]) GetValue(key K) (V, time.Duration, error) { value := zero[V]() ttl, err := h.Get(key, &value) + return value, ttl, err } - // Set adds a key-value pair to the cache with a specified TTL. +// Set adds a key-value pair to the cache with a specified TTL. func (h *DB[K, V]) Set(key K, value V, ttl time.Duration) error { keyData, err := marshal(key) if err != nil { return err } + valueData, err := marshal(value) if err != nil { return err } + h.Store.Set(keyData, valueData, ttl) + return nil } - // Delete removes a key-value pair from the cache. +// Delete removes a key-value pair from the cache. func (h *DB[K, V]) Delete(key K) error { keyData, err := marshal(key) if err != nil { return err } + ok := h.Store.Delete(keyData) if !ok { return ErrKeyNotFound } + return nil } diff --git a/conn_test.go b/conn_test.go index d6c284e..a97a5e9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,39 +1,25 @@ package cache import ( - "fmt" + "strconv" "testing" "time" - "github.com/stretchr/testify/assert" + "errors" ) -func setupTestDB[K any, V any](t testing.TB) *DB[K, V] { - t.Helper() +func setupTestDB[K any, V any](tb testing.TB) *DB[K, V] { + tb.Helper() db, err := OpenMem[K, V]() - assert.NoError(t, err) - t.Cleanup(func() { + if err != nil { + tb.Fatalf("unexpected error: %v", err) + } + tb.Cleanup(func() { db.Close() }) - return &db -func TestDBConcurrentAccess(t *testing.T) { - db := setupTestDB[string, string](t) - - go func() { - for i := 0; i < 100; i++ { - db.Set(fmt.Sprintf("Key%d", i), "Value", 0) - } - }() - go func() { - for i := 0; i < 100; i++ { - db.GetValue(fmt.Sprintf("Key%d", i)) - } - }() - - // Allow some time for goroutines to complete - time.Sleep(1 * time.Second) + return &db } func TestDBGetSet(t *testing.T) { @@ -45,15 +31,22 @@ func TestDBGetSet(t *testing.T) { db := setupTestDB[string, string](t) want := "Value" - err := db.Set("Key", want, 1*time.Hour) - assert.NoError(t, err) + + if err := db.Set("Key", want, 1*time.Hour); err != nil { + t.Fatalf("unexpected error: %v", err) + } got, ttl, err := db.GetValue("Key") - assert.NoError(t, err) - assert.Equal(t, want, got) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if want != got { + t.Fatalf("expected: %v, got: %v", want, got) + } - now := time.Now() - assert.WithinDuration(t, now.Add(ttl), now.Add(1*time.Hour), 1*time.Millisecond) + if ttl.Round(time.Second) != 1*time.Hour { + t.Fatalf("expected duration %v, got: %v", time.Hour, ttl.Round(time.Second)) + } }) t.Run("Not Exists", func(t *testing.T) { @@ -61,8 +54,9 @@ func TestDBGetSet(t *testing.T) { db := setupTestDB[string, string](t) - _, _, err := db.GetValue("Key") - assert.ErrorIs(t, err, ErrKeyNotFound) + if _, _, err := db.GetValue("Key"); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected error: %v, got: %v", ErrKeyNotFound, err) + } }) t.Run("Update", func(t *testing.T) { @@ -70,16 +64,22 @@ func TestDBGetSet(t *testing.T) { db := setupTestDB[string, string](t) - err := db.Set("Key", "Other", 0) - assert.NoError(t, err) + if err := db.Set("Key", "Other", 0); err != nil { + t.Fatalf("expected no error, got: %v", err) + } want := "Value" - err = db.Set("Key", want, 0) - assert.NoError(t, err) + if err := db.Set("Key", want, 0); err != nil { + t.Fatalf("unexpected error: %v", err) + } got, _, err := db.GetValue("Key") - assert.NoError(t, err) - assert.Equal(t, want, got) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if want != got { + t.Fatalf("expected: %v, got: %v", want, got) + } }) t.Run("Key Expiry", func(t *testing.T) { @@ -87,13 +87,15 @@ func TestDBGetSet(t *testing.T) { db := setupTestDB[string, string](t) - err := db.Set("Key", "Value", 500*time.Millisecond) - assert.NoError(t, err) + if err := db.Set("Key", "Value", 500*time.Millisecond); err != nil { + t.Fatalf("unexpected error: %v", err) + } time.Sleep(600 * time.Millisecond) - _, _, err = db.GetValue("Key") - assert.ErrorIs(t, err, ErrKeyNotFound) + if _, _, err := db.GetValue("Key"); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected error: %v, got: %v", ErrKeyNotFound, err) + } }) } @@ -105,14 +107,17 @@ func TestDBDelete(t *testing.T) { db := setupTestDB[string, string](t) want := "Value" - err := db.Set("Key", want, 0) - assert.NoError(t, err) + if err := db.Set("Key", want, 0); err != nil { + t.Fatalf("unexpected error: %v", err) + } - err = db.Delete("Key") - assert.NoError(t, err) + if err := db.Delete("Key"); err != nil { + t.Fatalf("unexpected error: %v", err) + } - _, _, err = db.GetValue("Key") - assert.ErrorIs(t, err, ErrKeyNotFound) + if _, _, err := db.GetValue("Key"); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected error: %v, got: %v", ErrKeyNotFound, err) + } }) t.Run("Not Exists", func(t *testing.T) { @@ -120,22 +125,25 @@ func TestDBDelete(t *testing.T) { db := setupTestDB[string, string](t) - err := db.Delete("Key") - assert.ErrorIs(t, err, ErrKeyNotFound) + if err := db.Delete("Key"); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected error: %v, got: %v", ErrKeyNotFound, err) + } }) } func BenchmarkDBGet(b *testing.B) { for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { db := setupTestDB[int, int](b) - for i := 0; i < n; i++ { + for i := range n { db.Set(i, i, 0) } + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { db.GetValue(n - 1) } }) @@ -144,14 +152,16 @@ func BenchmarkDBGet(b *testing.B) { func BenchmarkDBSet(b *testing.B) { for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { db := setupTestDB[int, int](b) - for i := 0; i < n-1; i++ { + for i := range n - 1 { db.Set(i, i, 0) } + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { db.Set(n, n, 0) } }) @@ -160,14 +170,16 @@ func BenchmarkDBSet(b *testing.B) { func BenchmarkDBDelete(b *testing.B) { for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { db := setupTestDB[int, int](b) - for i := 0; i < n-1; i++ { + for i := range n - 1 { db.Set(i, i, 0) } + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { db.Set(n, n, 0) db.Delete(n) } diff --git a/encoding.go b/encoding.go index bc68d2f..1327087 100644 --- a/encoding.go +++ b/encoding.go @@ -26,6 +26,7 @@ func (e *encoder) Flush() error { func (e *encoder) EncodeUint64(val uint64) error { binary.LittleEndian.PutUint64(e.buf, val) _, err := e.w.Write(e.buf) + return err } @@ -39,6 +40,7 @@ func (e *encoder) EncodeBytes(val []byte) error { } _, err := e.w.Write(val) + return err } @@ -84,6 +86,7 @@ func (e *encoder) EncodeStore(s *store) error { return err } } + return nil } @@ -104,6 +107,7 @@ func (d *decoder) DecodeUint64() (uint64, error) { if err != nil { return 0, err } + return binary.LittleEndian.Uint64(d.buf), nil } @@ -112,7 +116,12 @@ func (d *decoder) DecodeTime() (time.Time, error) { if err != nil { return zero[time.Time](), err } - return time.Unix(int64(ts), 0), nil + + t := time.Unix(int64(ts), 0) + if t.IsZero() { + t = zero[time.Time]() + } + return t, nil } func (d *decoder) DecodeBytes() ([]byte, error) { @@ -120,8 +129,10 @@ func (d *decoder) DecodeBytes() ([]byte, error) { if err != nil { return nil, err } + data := make([]byte, lenVal) _, err = io.ReadFull(d.r, data) + return data, err } @@ -132,18 +143,21 @@ func (d *decoder) DecodeNodes() (*node, error) { if err != nil { return nil, err } + n.Hash = hash expiration, err := d.DecodeTime() if err != nil { return nil, err } + n.Expiration = expiration access, err := d.DecodeUint64() if err != nil { return nil, err } + n.Access = access n.Key, err = d.DecodeBytes() @@ -155,6 +169,7 @@ func (d *decoder) DecodeNodes() (*node, error) { if err != nil { return nil, err } + return n, err } @@ -163,18 +178,21 @@ func (d *decoder) DecodeStore(s *store) error { if err != nil { return err } + s.MaxCost = maxCost policy, err := d.DecodeUint64() if err != nil { return err } + s.Policy.SetPolicy(EvictionPolicyType(policy)) length, err := d.DecodeUint64() if err != nil { return err } + s.Length = length k := 128 @@ -206,6 +224,7 @@ func (d *decoder) DecodeStore(s *store) error { s.Cost = s.Cost + uint64(len(v.Key)) + uint64(len(v.Value)) } + return nil } @@ -227,6 +246,7 @@ func (s *store) LoadSnapshot(r io.ReadSeeker) error { if _, err := r.Seek(0, io.SeekStart); err != nil { return err } + d := newDecoder(r) return d.DecodeStore(s) diff --git a/encoding_test.go b/encoding_test.go index bac148c..4f3af4f 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -5,13 +5,22 @@ import ( "encoding/binary" "fmt" "os" + "strconv" "testing" "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) +func TestDecodeUint64Error(t *testing.T) { + buf := bytes.NewReader([]byte{0xFF}) + + decoder := newDecoder(buf) + + _, err := decoder.DecodeUint64() + if err == nil { + t.Errorf("expected an error but got none") + } +} + func TestEncodeDecodeUint64(t *testing.T) { tests := []struct { name string @@ -28,52 +37,99 @@ func TestEncodeDecodeUint64(t *testing.T) { e := newEncoder(&buf) err := e.EncodeUint64(tt.value) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } err = e.Flush() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } decoder := newDecoder(bytes.NewReader(buf.Bytes())) decodedValue, err := decoder.DecodeUint64() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - assert.Equal(t, tt.value, decodedValue) + if tt.value != decodedValue { + t.Errorf("expected %v, got %v", tt.value, decodedValue) + } }) } -func TestDecodeUint64Error(t *testing.T) { - var buf bytes.Buffer - buf.Write([]byte{0xFF}) // Invalid data for uint64 - decoder := newDecoder(&buf) +} + +func TestEncodeDecodeStoreWithPolicies(t *testing.T) { + policies := []EvictionPolicyType{PolicyFIFO, PolicyLRU, PolicyLFU, PolicyLTR} + + for _, policy := range policies { + t.Run(fmt.Sprintf("Policy_%d", policy), func(t *testing.T) { + var buf bytes.Buffer + e := newEncoder(&buf) + + store := setupTestStore(t) + store.Policy.SetPolicy(policy) + + err := e.EncodeStore(store) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + err = e.Flush() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + decoder := newDecoder(bytes.NewReader(buf.Bytes())) + decodedStore := setupTestStore(t) + + err = decoder.DecodeStore(decodedStore) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if store.Policy.Type != decodedStore.Policy.Type { + t.Errorf("expected %v, got %v", store.Policy.Type, decodedStore.Policy.Type) + } + }) + } +} - _, err := decoder.DecodeUint64() - assert.Error(t, err) func TestEncodeDecodeTimeBoundary(t *testing.T) { - tests := []struct { - name string - value time.Time - }{ - {name: "Unix Epoch", value: time.Unix(0, 0)}, - {name: "Far Future", value: time.Unix(1<<63-1, 0)}, - } + tests := []struct { + name string + value time.Time + }{ + {name: "Unix Epoch", value: time.Unix(0, 0)}, + {name: "Far Future", value: time.Unix(1<<63-1, 0)}, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - e := newEncoder(&buf) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + e := newEncoder(&buf) + + err := e.EncodeTime(tt.value) - err := e.EncodeTime(tt.value) - assert.NoError(t, err) - err = e.Flush() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if err != nil { + t.Errorf("unexpected error: %v", err) + } + err = e.Flush() - decoder := newDecoder(bytes.NewReader(buf.Bytes())) + decoder := newDecoder(bytes.NewReader(buf.Bytes())) - decodedValue, err := decoder.DecodeTime() - assert.NoError(t, err) + decodedValue, err := decoder.DecodeTime() + if err != nil { + t.Errorf("unexpected error: %v", err) + } - assert.Equal(t, tt.value, decodedValue) - }) - } + if tt.value != decodedValue { + t.Errorf("expected %v, got %v", tt.value, decodedValue) + } + }) + } } func TestEncodeDecodeTime(t *testing.T) { @@ -82,6 +138,7 @@ func TestEncodeDecodeTime(t *testing.T) { value time.Time }{ {name: "Time Now", value: time.Now()}, + {name: "Unix Epoch", value: time.Unix(0, 0)}, {name: "Time Zero", value: time.Time{}}, } @@ -91,20 +148,50 @@ func TestEncodeDecodeTime(t *testing.T) { e := newEncoder(&buf) err := e.EncodeTime(tt.value) - assert.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } err = e.Flush() - assert.NoError(t, err) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } decoder := newDecoder(bytes.NewReader(buf.Bytes())) decodedValue, err := decoder.DecodeTime() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if tt.value.Unix() != decodedValue.Unix() { + t.Errorf("expected %v, got %v", tt.value, decodedValue) + } - assert.WithinDuration(t, tt.value, decodedValue, time.Second) }) } } +func TestDecodeBytesError(t *testing.T) { + var buf bytes.Buffer + e := newEncoder(&buf) + + err := e.EncodeBytes([]byte("DEADBEEF")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + err = e.Flush() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + decoder := newDecoder(bytes.NewReader(buf.Bytes()[:10])) + + _, err = decoder.DecodeBytes() + if err == nil { + t.Errorf("expected an error but got none") + } +} + func TestEncodeDecodeBytes(t *testing.T) { tests := []struct { name string @@ -121,16 +208,24 @@ func TestEncodeDecodeBytes(t *testing.T) { e := newEncoder(&buf) err := e.EncodeBytes(tt.value) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } err = e.Flush() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } decoder := newDecoder(bytes.NewReader(buf.Bytes())) decodedValue, err := decoder.DecodeBytes() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - assert.Equal(t, tt.value, decodedValue) + if !bytes.Equal(tt.value, decodedValue) { + t.Errorf("expected %v, got %v", tt.value, decodedValue) + } }) } } @@ -178,20 +273,36 @@ func TestEncodeDecodeNode(t *testing.T) { e := newEncoder(&buf) err := e.EncodeNode(tt.value) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } err = e.Flush() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } decoder := newDecoder(bytes.NewReader(buf.Bytes())) decodedValue, err := decoder.DecodeNodes() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - assert.Equal(t, tt.value.Hash, decodedValue.Hash) - assert.WithinDuration(t, tt.value.Expiration, decodedValue.Expiration, 1*time.Second) - assert.Equal(t, tt.value.Access, decodedValue.Access) - assert.Equal(t, tt.value.Key, decodedValue.Key) - assert.Equal(t, tt.value.Value, decodedValue.Value) + if tt.value.Hash != decodedValue.Hash { + t.Errorf("expected %v, got %v", tt.value.Hash, decodedValue.Hash) + } + if !tt.value.Expiration.Equal(decodedValue.Expiration) && tt.value.Expiration.Sub(decodedValue.Expiration) > time.Second { + t.Errorf("expected %v to be within %v of %v", decodedValue.Expiration, time.Second, tt.value.Expiration) + } + if tt.value.Access != decodedValue.Access { + t.Errorf("expected %v, got %v", tt.value.Access, decodedValue.Access) + } + if !bytes.Equal(tt.value.Key, decodedValue.Key) { + t.Errorf("expected %v, got %v", tt.value.Key, decodedValue.Key) + } + if !bytes.Equal(tt.value.Value, decodedValue.Value) { + t.Errorf("expected %v, got %v", tt.value.Value, decodedValue.Value) + } }) } } @@ -238,55 +349,72 @@ func TestEncodeDecodeStrorage(t *testing.T) { want := setupTestStore(t) want.MaxCost = uint64(tt.maxCost) err := want.Policy.SetPolicy(tt.policy) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } for k, v := range tt.store { want.Set([]byte(k), []byte(v), 0) } err = e.EncodeStore(want) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } err = e.Flush() - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } decoder := newDecoder(bytes.NewReader(buf.Bytes())) got := setupTestStore(t) err = decoder.DecodeStore(got) - assert.NoError(t, err) + if err != nil { + t.Errorf("unexpected error: %v", err) + } - assert.Equal(t, want.MaxCost, got.MaxCost) - assert.Equal(t, want.Length, got.Length) - assert.Equal(t, want.Policy.Type, got.Policy.Type) + if want.MaxCost != got.MaxCost { + t.Errorf("expected %v, got %v", want.MaxCost, got.MaxCost) + } + if want.Length != got.Length { + t.Errorf("expected %v, got %v", want.Length, got.Length) + } + if want.Policy.Type != got.Policy.Type { + t.Errorf("expected %v, got %v", want.Policy.Type, got.Policy.Type) + } gotOrder := getListOrder(t, &got.Evict) for i, v := range getListOrder(t, &want.Evict) { - assert.Equal(t, v.Key, gotOrder[i].Key) + if !bytes.Equal(v.Key, gotOrder[i].Key) { + t.Errorf("expected %#v, got %#v", v.Key, gotOrder[i].Key) + } } for k, v := range tt.store { gotVal, _, ok := want.Get([]byte(k)) - require.True(t, ok) - require.Equal(t, []byte(v), gotVal) + if !ok { + t.Fatalf("expected condition to be true") + } + if !bytes.Equal([]byte(v), gotVal) { + t.Fatalf("expected %v, got %v", []byte(v), gotVal) + } } }) } } -type MockSeeker struct { - *bytes.Buffer -} - func BenchmarkEncoder_EncodeStore(b *testing.B) { - file, err := os.CreateTemp("", "benchmark_test_") + file, err := os.CreateTemp(b.TempDir(), "benchmark_test_") if err != nil { b.Fatal(err) } + defer os.Remove(file.Name()) defer file.Close() for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { want := setupTestStore(b) for i := range n { @@ -296,32 +424,40 @@ func BenchmarkEncoder_EncodeStore(b *testing.B) { } err = want.Snapshot(file) - require.NoError(b, err) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } fileInfo, err := file.Stat() - require.NoError(b, err) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } b.SetBytes(int64(fileInfo.Size())) b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { want.Snapshot(file) } }) } - } func BenchmarkDecoder_DecodeStore(b *testing.B) { + file, err := os.CreateTemp(b.TempDir(), "benchmark_test_") + + if err != nil { + b.Errorf("unexpected error: %v", err) + } - file, err := os.CreateTemp("", "benchmark_test_") - require.NoError(b, err) defer os.Remove(file.Name()) defer file.Close() for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { want := setupTestStore(b) + for i := range n { buf := make([]byte, 8) binary.LittleEndian.PutUint64(buf, uint64(i)) @@ -329,14 +465,19 @@ func BenchmarkDecoder_DecodeStore(b *testing.B) { } err = want.Snapshot(file) - require.NoError(b, err) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } fileInfo, err := file.Stat() - require.NoError(b, err) + if err != nil { + b.Fatalf("unexpected error: %v", err) + } b.SetBytes(int64(fileInfo.Size())) b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { want.LoadSnapshot(file) } }) @@ -7,8 +7,8 @@ import ( // EvictionPolicyType defines the type of eviction policy. type EvictionPolicyType int -const ( - // PolicyNone indicates no eviction policy. +const ( + // PolicyNone indicates no eviction policy. PolicyNone EvictionPolicyType = iota PolicyFIFO PolicyLRU @@ -17,7 +17,6 @@ const ( ) // evictionStrategies interface defines the methods for eviction strategies. - // evictionStrategies interface defines the methods for eviction strategies. type evictionStrategies interface { OnInsert(n *node) OnUpdate(n *node) @@ -26,14 +25,13 @@ type evictionStrategies interface { } // evictionPolicy struct holds the eviction strategy and its type. - // evictionPolicy struct holds the eviction strategy and its type. type evictionPolicy struct { evictionStrategies Type EvictionPolicyType evict *node } - // pushEvict adds a node to the eviction list. +// pushEvict adds a node to the eviction list. func pushEvict(node *node, sentinnel *node) { node.EvictPrev = sentinnel node.EvictNext = node.EvictPrev.EvictNext @@ -42,7 +40,6 @@ func pushEvict(node *node, sentinnel *node) { } // SetPolicy sets the eviction policy based on the given type. - // SetPolicy sets the eviction policy based on the given type. func (e *evictionPolicy) SetPolicy(y EvictionPolicyType) error { store := map[EvictionPolicyType]func() evictionStrategies{ PolicyNone: func() evictionStrategies { @@ -61,39 +58,41 @@ func (e *evictionPolicy) SetPolicy(y EvictionPolicyType) error { return ltrPolicy{evict: e.evict} }, } + factory, ok := store[y] if !ok { return errors.New("invalid policy") } + e.evictionStrategies = factory() + return nil } -// fifoPolicy struct represents the First-In-First-Out eviction policy. - // fifoPolicy struct represents the First-In-First-Out eviction policy. +type evictOrderedPolicy interface { + evictionStrategies + getEvict() *node +} + type fifoPolicy struct { evict *node shouldEvict bool } // OnInsert adds a node to the eviction list. - // OnInsert adds a node to the eviction list. func (s fifoPolicy) OnInsert(node *node) { pushEvict(node, s.evict) } // OnAccess is a no-op for fifoPolicy. - // OnAccess is a no-op for fifoPolicy. func (fifoPolicy) OnAccess(n *node) { } // OnUpdate is a no-op for fifoPolicy. - // OnUpdate is a no-op for fifoPolicy. func (fifoPolicy) OnUpdate(n *node) { } // Evict returns the oldest node for fifoPolicy. - // Evict returns the oldest node for fifoPolicy. func (s fifoPolicy) Evict() *node { if s.shouldEvict && s.evict.EvictPrev != s.evict { return s.evict.EvictPrev @@ -102,25 +101,26 @@ func (s fifoPolicy) Evict() *node { } } +func (s fifoPolicy) getEvict() *node { + return s.evict +} + // lruPolicy struct represents the Least Recently Used eviction policy. - // lruPolicy struct represents the Least Recently Used eviction policy. type lruPolicy struct { evict *node } // OnInsert adds a node to the eviction list. - // OnInsert adds a node to the eviction list. func (s lruPolicy) OnInsert(node *node) { pushEvict(node, s.evict) } - // OnUpdate moves the accessed node to the front of the eviction list. +// OnUpdate moves the accessed node to the front of the eviction list. func (s lruPolicy) OnUpdate(node *node) { s.OnAccess(node) } // OnAccess moves the accessed node to the front of the eviction list. - // OnAccess moves the accessed node to the front of the eviction list. func (s lruPolicy) OnAccess(node *node) { node.EvictNext.EvictPrev = node.EvictPrev node.EvictPrev.EvictNext = node.EvictNext @@ -128,7 +128,6 @@ func (s lruPolicy) OnAccess(node *node) { } // Evict returns the least recently used node for lruPolicy. - // Evict returns the least recently used node for lruPolicy. func (s lruPolicy) Evict() *node { if s.evict.EvictPrev != s.evict { return s.evict.EvictPrev @@ -137,26 +136,26 @@ func (s lruPolicy) Evict() *node { } } +func (s lruPolicy) getEvict() *node { + return s.evict +} + // lfuPolicy struct represents the Least Frequently Used eviction policy. - // lfuPolicy struct represents the Least Frequently Used eviction policy. type lfuPolicy struct { evict *node } // OnInsert adds a node to the eviction list and initializes its access count. - // OnInsert adds a node to the eviction list and initializes its access count. func (s lfuPolicy) OnInsert(node *node) { pushEvict(node, s.evict) } // OnUpdate increments the access count of the node and reorders the list. - // OnUpdate increments the access count of the node and reorders the list. func (s lfuPolicy) OnUpdate(node *node) { s.OnAccess(node) } // OnAccess increments the access count of the node and reorders the list. - // OnAccess increments the access count of the node and reorders the list. func (s lfuPolicy) OnAccess(node *node) { node.Access++ @@ -169,9 +168,11 @@ func (s lfuPolicy) OnAccess(node *node) { node.EvictNext = node.EvictPrev.EvictNext node.EvictNext.EvictPrev = node node.EvictPrev.EvictNext = node + return } } + node.EvictNext.EvictPrev = node.EvictPrev node.EvictPrev.EvictNext = node.EvictNext @@ -182,7 +183,6 @@ func (s lfuPolicy) OnAccess(node *node) { } // Evict returns the least frequently used node for LFU. - // Evict returns the least frequently used node for LFU. func (s lfuPolicy) Evict() *node { if s.evict.EvictPrev != s.evict { return s.evict.EvictPrev @@ -191,8 +191,11 @@ func (s lfuPolicy) Evict() *node { } } +func (s ltrPolicy) getEvict() *node { + return s.evict +} + // ltrPolicy struct represents the Least Remaining Time eviction policy. - // ltrPolicy struct represents the Least Remaining Time eviction policy. type ltrPolicy struct { evict *node evictZero bool @@ -200,7 +203,6 @@ type ltrPolicy struct { // OnInsert adds a node to the eviction list based on its TTL (Time To Live). // It places the node in the correct position in the list based on TTL. - // OnInsert adds a node to the eviction list based on its TTL (Time To Live). func (s ltrPolicy) OnInsert(node *node) { pushEvict(node, s.evict) @@ -209,21 +211,21 @@ func (s ltrPolicy) OnInsert(node *node) { // OnAccess is a no-op for ltrPolicy. // It does not perform any action when a node is accessed. - // OnAccess is a no-op for ltrPolicy. func (s ltrPolicy) OnAccess(node *node) { } // OnUpdate updates the position of the node in the eviction list based on its TTL. // It reorders the list to maintain the correct order based on TTL. - // OnUpdate updates the position of the node in the eviction list based on its TTL. func (s ltrPolicy) OnUpdate(node *node) { if node.TTL() == 0 { return } + for v := node.EvictPrev; v.EvictPrev != s.evict; v = v.EvictPrev { if v.TTL() == 0 { continue } + if v.TTL() < node.TTL() { node.EvictNext.EvictPrev = node.EvictPrev node.EvictPrev.EvictNext = node.EvictNext @@ -232,13 +234,16 @@ func (s ltrPolicy) OnUpdate(node *node) { node.EvictNext = node.EvictPrev.EvictNext node.EvictNext.EvictPrev = node node.EvictPrev.EvictNext = node + return } } + for v := node.EvictNext; v.EvictNext != s.evict; v = v.EvictNext { if v.TTL() == 0 { continue } + if v.TTL() > node.TTL() { node.EvictNext.EvictPrev = node.EvictPrev node.EvictPrev.EvictNext = node.EvictNext @@ -247,6 +252,7 @@ func (s ltrPolicy) OnUpdate(node *node) { node.EvictNext = node.EvictPrev.EvictNext node.EvictNext.EvictPrev = node node.EvictPrev.EvictNext = node + return } } @@ -254,10 +260,14 @@ func (s ltrPolicy) OnUpdate(node *node) { // Evict returns the node with the least remaining time to live for ltrPolicy. // It returns the node at the end of the eviction list. - // Evict returns the node with the least remaining time to live for ltrPolicy. func (s ltrPolicy) Evict() *node { if s.evict.EvictPrev != s.evict && (s.evict.EvictPrev.TTL() != 0 || s.evictZero) { return s.evict.EvictPrev } + return nil } + +func (s lfuPolicy) getEvict() *node { + return s.evict +} diff --git a/evict_test.go b/evict_test.go index 1813e94..5938b51 100644 --- a/evict_test.go +++ b/evict_test.go @@ -3,33 +3,55 @@ package cache import ( "testing" "time" - - "github.com/stretchr/testify/assert" ) -func createSentinel(t testing.TB) *node { - t.Helper() +func createSentinel(tb testing.TB) *node { + tb.Helper() + n1 := &node{Key: []byte("Sentinel")} n1.EvictNext = n1 n1.EvictPrev = n1 + return n1 } -func getListOrder(t testing.TB, evict *node) []*node { - t.Helper() +func getListOrder(tb testing.TB, evict *node) []*node { + tb.Helper() var order []*node + current := evict.EvictNext for current != evict { order = append(order, current) current = current.EvictNext } + for _, n := range order { - assert.Same(t, n, n.EvictPrev.EvictNext) + tb.Helper() + if n != n.EvictPrev.EvictNext { + tb.Fatalf("expected %#v, got %#v", n, n.EvictPrev.EvictNext) + } } + return order } +func checkOrder(tb testing.TB, policy evictOrderedPolicy, expected []*node) { + tb.Helper() + + order := getListOrder(tb, policy.getEvict()) + + if len(order) != len(expected) { + tb.Errorf("expected length %v, got %v", len(expected), len(order)) + } + + for i, n := range expected { + if order[i] != n { + tb.Errorf("element %v did not match: \nexpected: %#v\n got: %#v", i, n, order[i]) + } + } +} + func TestFIFOPolicy(t *testing.T) { t.Parallel() @@ -41,13 +63,10 @@ func TestFIFOPolicy(t *testing.T) { n0 := &node{Key: []byte("0")} n1 := &node{Key: []byte("1")} - policy.OnInsert(n0) policy.OnInsert(n1) + policy.OnInsert(n0) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, order[0], n1) - assert.Same(t, order[1], n0) + checkOrder(t, policy, []*node{n0, n1}) }) t.Run("Evict", func(t *testing.T) { @@ -65,7 +84,9 @@ func TestFIFOPolicy(t *testing.T) { policy.OnInsert(n1) evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) + if n0 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("Evict noEvict", func(t *testing.T) { @@ -75,50 +96,25 @@ func TestFIFOPolicy(t *testing.T) { policy.OnInsert(&node{}) - assert.Nil(t, policy.Evict()) + if policy.Evict() != nil { + t.Errorf("expected nil, got %#v", policy.Evict()) + } }) t.Run("Empty List", func(t *testing.T) { t.Parallel() policy := fifoPolicy{evict: createSentinel(t)} - - assert.Nil(t, policy.Evict()) + if policy.Evict() != nil { + t.Errorf("expected nil, got %#v", policy.Evict()) + } }) }) - - t.Run("Eviction Order", func(t *testing.T) { - t.Parallel() - - policy := lfuPolicy{evict: createSentinel(t)} - - n0 := &node{Key: []byte("0"), Access: 1} - n1 := &node{Key: []byte("1"), Access: 1} - - policy.OnInsert(n0) - policy.OnInsert(n1) - - evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) // Assuming FIFO order for same access count - }) - - t.Run("With Zero TTL", func(t *testing.T) { - t.Parallel() - - policy := ltrPolicy{evict: createSentinel(t), evictZero: false} - - n0 := &node{Key: []byte("0"), Expiration: time.Time{}} - n1 := &node{Key: []byte("1"), Expiration: time.Now().Add(1 * time.Hour)} - - policy.OnInsert(n0) - policy.OnInsert(n1) - - evictedNode := policy.Evict() - assert.Same(t, n1, evictedNode) // n0 should not be evicted due to zero TTL - }) } func TestLRUPolicy(t *testing.T) { + t.Parallel() + t.Run("OnInsert", func(t *testing.T) { t.Parallel() @@ -130,10 +126,7 @@ func TestLRUPolicy(t *testing.T) { policy.OnInsert(n0) policy.OnInsert(n1) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, order[0], n1) - assert.Same(t, order[1], n0) + checkOrder(t, policy, []*node{n1, n0}) }) t.Run("OnAccess", func(t *testing.T) { @@ -149,10 +142,7 @@ func TestLRUPolicy(t *testing.T) { policy.OnAccess(n0) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, order[0], n0) - assert.Same(t, order[1], n1) + checkOrder(t, policy, []*node{n0, n1}) }) t.Run("Evict", func(t *testing.T) { @@ -170,7 +160,9 @@ func TestLRUPolicy(t *testing.T) { policy.OnInsert(n1) evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) + if n0 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("OnAccess End", func(t *testing.T) { @@ -187,7 +179,9 @@ func TestLRUPolicy(t *testing.T) { policy.OnAccess(n0) evictedNode := policy.Evict() - assert.Same(t, n1, evictedNode) + if n1 != evictedNode { + t.Errorf("expected %#v, got %#v", n1, evictedNode) + } }) t.Run("OnAccess Interleaved", func(t *testing.T) { @@ -203,15 +197,19 @@ func TestLRUPolicy(t *testing.T) { policy.OnInsert(n1) evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) + + if n0 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("Empty", func(t *testing.T) { t.Parallel() policy := lruPolicy{evict: createSentinel(t)} - - assert.Nil(t, policy.Evict()) + if policy.Evict() != nil { + t.Errorf("expected nil, got %#v", policy.Evict()) + } }) }) } @@ -230,10 +228,7 @@ func TestLFUPolicy(t *testing.T) { policy.OnInsert(n0) policy.OnInsert(n1) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Contains(t, order, n0) - assert.Contains(t, order, n1) + checkOrder(t, policy, []*node{n1, n0}) }) t.Run("OnAccess", func(t *testing.T) { @@ -249,10 +244,7 @@ func TestLFUPolicy(t *testing.T) { policy.OnAccess(n0) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, order[0], n0) - assert.Same(t, order[1], n1) + checkOrder(t, policy, []*node{n0, n1}) }) t.Run("Evict", func(t *testing.T) { @@ -272,7 +264,9 @@ func TestLFUPolicy(t *testing.T) { policy.OnAccess(n0) evictedNode := policy.Evict() - assert.Same(t, n1, evictedNode) + if n1 != evictedNode { + t.Errorf("expected %#v, got %#v", n1, evictedNode) + } }) t.Run("Evict After Multiple Accesses", func(t *testing.T) { @@ -292,15 +286,19 @@ func TestLFUPolicy(t *testing.T) { policy.OnAccess(n1) evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) + + if n0 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("Empty List", func(t *testing.T) { t.Parallel() policy := lfuPolicy{evict: createSentinel(t)} - - assert.Nil(t, policy.Evict()) + if policy.Evict() != nil { + t.Errorf("expected nil, got %#v", policy.Evict()) + } }) }) } @@ -322,10 +320,7 @@ func TestLTRPolicy(t *testing.T) { policy.OnInsert(n0) policy.OnInsert(n1) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, n0, order[0]) - assert.Same(t, n1, order[1]) + checkOrder(t, policy, []*node{n0, n1}) }) t.Run("Without TTL", func(t *testing.T) { @@ -339,10 +334,7 @@ func TestLTRPolicy(t *testing.T) { policy.OnInsert(n0) policy.OnInsert(n1) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, n1, order[0]) - assert.Same(t, n0, order[1]) + checkOrder(t, policy, []*node{n1, n0}) }) }) @@ -363,10 +355,7 @@ func TestLTRPolicy(t *testing.T) { n0.Expiration = time.Now().Add(3 * time.Hour) policy.OnUpdate(n0) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, n0, order[1]) - assert.Same(t, n1, order[0]) + checkOrder(t, policy, []*node{n1, n0}) }) t.Run("With TTL Decrease", func(t *testing.T) { @@ -383,10 +372,7 @@ func TestLTRPolicy(t *testing.T) { n1.Expiration = time.Now().Add(30 * time.Minute) policy.OnUpdate(n1) - order := getListOrder(t, policy.evict) - assert.Len(t, order, 2) - assert.Same(t, n1, order[1]) - assert.Same(t, n0, order[0]) + checkOrder(t, policy, []*node{n0, n1}) }) }) @@ -405,7 +391,25 @@ func TestLTRPolicy(t *testing.T) { policy.OnInsert(n1) evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) + if n0 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } + }) + + t.Run("no evictZero", func(t *testing.T) { + t.Parallel() + + policy := ltrPolicy{evict: createSentinel(t), evictZero: false} + + n0 := &node{Key: []byte("0")} + n1 := &node{Key: []byte("1")} + + policy.OnInsert(n0) + policy.OnInsert(n1) + + if policy.Evict() != nil { + t.Errorf("expected nil, got %#v", policy.Evict()) + } }) t.Run("Evict TTL", func(t *testing.T) { @@ -420,7 +424,10 @@ func TestLTRPolicy(t *testing.T) { policy.OnInsert(n1) evictedNode := policy.Evict() - assert.Same(t, n1, evictedNode) + + if n1 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("Evict TTL Update", func(t *testing.T) { @@ -438,7 +445,10 @@ func TestLTRPolicy(t *testing.T) { policy.OnUpdate(n0) evictedNode := policy.Evict() - assert.Same(t, n0, evictedNode) + + if n0 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("Evict TTL Update Down", func(t *testing.T) { @@ -456,15 +466,19 @@ func TestLTRPolicy(t *testing.T) { policy.OnUpdate(n1) evictedNode := policy.Evict() - assert.Same(t, n1, evictedNode) + + if n1 != evictedNode { + t.Errorf("expected %#v, got %#v", n0, evictedNode) + } }) t.Run("Empty List", func(t *testing.T) { t.Parallel() policy := ltrPolicy{evict: createSentinel(t), evictZero: true} - - assert.Nil(t, policy.Evict()) + if policy.Evict() != nil { + t.Errorf("expected nil, got %#v", policy.Evict()) + } }) }) } diff --git a/examples/basic_usage/main.go b/examples/basic_usage/main.go index 0582770..73e88ea 100644 --- a/examples/basic_usage/main.go +++ b/examples/basic_usage/main.go @@ -9,7 +9,7 @@ import ( func main() { // Create an in-memory cache - db, err := cache.OpenMem[string, string]("example") + db, err := cache.OpenMem[string, string]() if err != nil { fmt.Println("Error:", err) return diff --git a/examples/eviction_policy/main.go b/examples/eviction_policy/main.go index a9c8577..5b5599a 100644 --- a/examples/eviction_policy/main.go +++ b/examples/eviction_policy/main.go @@ -3,12 +3,12 @@ package main import ( "fmt" - "code.qburst.com/marcpervaz//cache" + "github.com/marcthe12/cache" ) func main() { // Create an in-memory cache with LRU eviction policy - db, err := cache.OpenMem[string, string]("example", cache.WithPolicy(cache.PolicyLRU)) + db, err := cache.OpenMem[string, string](cache.WithPolicy(cache.PolicyLRU)) if err != nil { fmt.Println("Error:", err) return @@ -3,14 +3,8 @@ module github.com/marcthe12/cache go 1.24.0 require ( - github.com/rogpeppe/go-internal v1.13.1 - github.com/stretchr/testify v1.10.0 + github.com/rogpeppe/go-internal v1.14.0 github.com/vmihailenco/msgpack/v5 v5.4.1 ) -require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) +require github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect @@ -1,16 +1,14 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/rogpeppe/go-internal v1.14.0 h1:unbRd941gNa8SS77YznHXOYVBDgWcF9xhzECdm8juZc= +github.com/rogpeppe/go-internal v1.14.0/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/pausedtimer/timer.go b/internal/pausedtimer/timer.go index 691c257..6ff212c 100644 --- a/internal/pausedtimer/timer.go +++ b/internal/pausedtimer/timer.go @@ -22,6 +22,7 @@ func New(d time.Duration) *PauseTimer { ret.Ticker = time.NewTicker(math.MaxInt64) ret.Reset(0) } + return ret } @@ -29,6 +30,7 @@ func New(d time.Duration) *PauseTimer { func NewStopped(d time.Duration) *PauseTimer { ret := New(d) ret.Stop() + return ret } diff --git a/internal/pausedtimer/timer_test.go b/internal/pausedtimer/timer_test.go index ac42690..924c07c 100644 --- a/internal/pausedtimer/timer_test.go +++ b/internal/pausedtimer/timer_test.go @@ -3,48 +3,76 @@ package pausedtimer import ( "testing" "time" - - "github.com/stretchr/testify/assert" ) func TestNew(t *testing.T) { d := 1 * time.Second timer := New(d) - assert.Equal(t, d, timer.duration) - assert.NotNil(t, timer.Ticker) + if timer.duration != d { + t.Errorf("expected duration %#v, got %v", d, timer.duration) + } + if timer.Ticker == nil { + t.Error("expected Ticker to be non-nil") + } +} + +func TestPauseTimerZeroDuration(t *testing.T) { + timer := New(0) + if timer.GetDuration() != 0 { + t.Errorf("expected duration %v, got %v", time.Duration(0), timer.GetDuration()) + } + if timer.Ticker == nil { + t.Error("expected Ticker to be non-nil") + } +} + +func TestPauseTimerResetToZero(t *testing.T) { + timer := New(1 * time.Second) + timer.Reset(0) + if timer.GetDuration() != 0 { + t.Errorf("expected duration %v, got %v", time.Duration(0), timer.GetDuration()) + } +} + func TestPauseTimerPauseAndResume(t *testing.T) { - d := 1 * time.Second - timer := New(d) - timer.Stop() // Simulate pause - time.Sleep(500 * time.Millisecond) - timer.Resume() + d := 1 * time.Second + timer := New(d) + timer.Stop() // Simulate pause + time.Sleep(500 * time.Millisecond) + timer.Resume() - select { - case <-timer.C: - // Timer should not have fired yet - t.Fatal("Timer fired too early") - case <-time.After(600 * time.Millisecond): - // Timer should fire after resuming - } + select { + case <-timer.C: + // Timer should not have fired yet + t.Fatal("Timer fired too early") + case <-time.After(600 * time.Millisecond): + // Timer should fire after resuming + } } func TestPauseTimerReset(t *testing.T) { d := 1 * time.Second timer := New(d) - newD := 2 * time.Second - timer.Reset(newD) - assert.Equal(t, newD, timer.duration) + got := 2 * time.Second + timer.Reset(got) + if timer.duration != got { + t.Errorf("expected duration %v, got %v", got, timer.duration) + } } func TestPauseTimerResume(t *testing.T) { d := 1 * time.Second timer := NewStopped(d) timer.Resume() - assert.Equal(t, d, timer.duration) + if timer.duration != d { + t.Errorf("expected duration %v, got %v", d, timer.duration) + } } func TestPauseTimerGetDuration(t *testing.T) { d := 1 * time.Second timer := New(d) - assert.Equal(t, d, timer.GetDuration()) + if timer.GetDuration() != d { + t.Errorf("expected duration %v, got %v", d, timer.GetDuration()) + } } @@ -10,25 +10,27 @@ import ( const initialBucketSize uint64 = 8 - // node represents an entry in the cache with metadata for eviction and expiration. +// node represents an entry in the cache with metadata for eviction and expiration. type node struct { Hash uint64 Expiration time.Time Access uint64 Key []byte Value []byte - HashNext *node - HashPrev *node - EvictNext *node - EvictPrev *node + mu sync.Mutex + + HashNext *node + HashPrev *node + EvictNext *node + EvictPrev *node } - // IsValid checks if the node is still valid based on its expiration time. +// IsValid checks if the node is still valid based on its expiration time. func (n *node) IsValid() bool { return n.Expiration.IsZero() || n.Expiration.After(time.Now()) } - // TTL returns the time-to-live of the node. +// TTL returns the time-to-live of the node. func (n *node) TTL() time.Duration { if n.Expiration.IsZero() { return 0 @@ -37,7 +39,7 @@ func (n *node) TTL() time.Duration { } } - // store represents the in-memory cache with eviction policies and periodic tasks. +// store represents the in-memory cache with eviction policies and periodic tasks. type store struct { Bucket []node Length uint64 @@ -47,19 +49,23 @@ type store struct { SnapshotTicker *pausedtimer.PauseTimer CleanupTicker *pausedtimer.PauseTimer Policy evictionPolicy - mu sync.Mutex + + mu sync.Mutex } - // Init initializes the store with default settings. +// Init initializes the store with default settings. func (s *store) Init() { s.Clear() s.Policy.evict = &s.Evict s.SnapshotTicker = pausedtimer.NewStopped(0) s.CleanupTicker = pausedtimer.NewStopped(10 * time.Second) - s.Policy.SetPolicy(PolicyNone) + + if err := s.Policy.SetPolicy(PolicyNone); err != nil { + panic(err) + } } - // Clear removes all entries from the store. +// Clear removes all entries from the store. func (s *store) Clear() { s.mu.Lock() defer s.mu.Unlock() @@ -72,13 +78,14 @@ func (s *store) Clear() { s.Evict.EvictPrev = &s.Evict } - // lookup calculates the hash and index for a given key. +// lookup calculates the hash and index for a given key. func lookup(s *store, key []byte) (uint64, uint64) { hash := hash(key) + return hash % uint64(len(s.Bucket)), hash } - // lazyInitBucket initializes the hash bucket if it hasn't been initialized yet. +// lazyInitBucket initializes the hash bucket if it hasn't been initialized yet. func lazyInitBucket(n *node) { if n.HashNext == nil { n.HashNext = n @@ -86,7 +93,7 @@ func lazyInitBucket(n *node) { } } - // lookup finds a node in the store by key. +// lookup finds a node in the store by key. func (s *store) lookup(key []byte) (*node, uint64, uint64) { idx, hash := lookup(s, key) @@ -103,22 +110,25 @@ func (s *store) lookup(key []byte) (*node, uint64, uint64) { return nil, idx, hash } - // get retrieves a value from the store by key. +// get retrieves a value from the store by key. func (s *store) get(key []byte) ([]byte, time.Duration, bool) { v, _, _ := s.lookup(key) if v != nil { if !v.IsValid() { deleteNode(s, v) + return nil, 0, false } + s.Policy.OnAccess(v) + return v.Value, v.TTL(), true } return nil, 0, false } - // Get retrieves a value from the store by key with locking. +// Get retrieves a value from the store by key with locking. func (s *store) Get(key []byte) ([]byte, time.Duration, bool) { s.mu.Lock() defer s.mu.Unlock() @@ -126,7 +136,7 @@ func (s *store) Get(key []byte) ([]byte, time.Duration, bool) { return s.get(key) } - // resize doubles the size of the hash table and rehashes all entries. +// resize doubles the size of the hash table and rehashes all entries. func resize(s *store) { bucket := make([]node, 2*len(s.Bucket)) @@ -149,7 +159,7 @@ func resize(s *store) { s.Bucket = bucket } - // cleanup removes expired entries from the store. +// cleanup removes expired entries from the store. func cleanup(s *store) { for v := s.Evict.EvictNext; v != &s.Evict; v = v.EvictNext { if !v.IsValid() { @@ -158,19 +168,21 @@ func cleanup(s *store) { } } - // evict removes entries from the store based on the eviction policy. +// evict removes entries from the store based on the eviction policy. func evict(s *store) bool { for s.MaxCost != 0 && s.MaxCost < s.Cost { n := s.Policy.Evict() if n == nil { break } + deleteNode(s, n) } + return true } - // set adds or updates a key-value pair in the store. +// set adds or updates a key-value pair in the store. func (s *store) set(key []byte, value []byte, ttl time.Duration) { v, idx, hash := s.lookup(key) if v != nil { @@ -181,9 +193,10 @@ func (s *store) set(key []byte, value []byte, ttl time.Duration) { } bucket := &s.Bucket[idx] + if float64(s.Length)/float64(len(s.Bucket)) > 0.75 { resize(s) - //resize may invidate pointer to bucket + // resize may invidate pointer to bucket _, idx, _ := s.lookup(key) bucket = &s.Bucket[idx] lazyInitBucket(bucket) @@ -210,7 +223,7 @@ func (s *store) set(key []byte, value []byte, ttl time.Duration) { s.Length = s.Length + 1 } - // Set adds or updates a key-value pair in the store with locking. +// Set adds or updates a key-value pair in the store with locking. func (s *store) Set(key []byte, value []byte, ttl time.Duration) { s.mu.Lock() defer s.mu.Unlock() @@ -218,7 +231,7 @@ func (s *store) Set(key []byte, value []byte, ttl time.Duration) { s.set(key, value, ttl) } - // deleteNode removes a node from the store. +// deleteNode removes a node from the store. func deleteNode(s *store, v *node) { v.HashNext.HashPrev = v.HashPrev v.HashPrev.HashNext = v.HashNext @@ -234,18 +247,19 @@ func deleteNode(s *store, v *node) { s.Length = s.Length - 1 } - // delete removes a key-value pair from the store. +// delete removes a key-value pair from the store. func (s *store) delete(key []byte) bool { v, _, _ := s.lookup(key) if v != nil { deleteNode(s, v) + return true } return false } - // Delete removes a key-value pair from the store with locking. +// Delete removes a key-value pair from the store with locking. func (s *store) Delete(key []byte) bool { s.mu.Lock() defer s.mu.Unlock() diff --git a/store_test.go b/store_test.go index 9b802bf..89f4de6 100644 --- a/store_test.go +++ b/store_test.go @@ -1,19 +1,19 @@ package cache import ( + "bytes" "encoding/binary" - "fmt" + "strconv" "testing" "time" - - "github.com/stretchr/testify/assert" ) -func setupTestStore(t testing.TB) *store { - t.Helper() +func setupTestStore(tb testing.TB) *store { + tb.Helper() store := &store{} store.Init() + return store } @@ -28,11 +28,16 @@ func TestStoreGetSet(t *testing.T) { want := []byte("Value") store.Set([]byte("Key"), want, 1*time.Hour) got, ttl, ok := store.Get([]byte("Key")) - assert.Equal(t, want, got) + if !ok { + t.Errorf("expected key to exist") + } + if !bytes.Equal(want, got) { + t.Errorf("got %v, want %v", got, want) + } + if ttl.Round(time.Second) != 1*time.Hour { + t.Errorf("ttl same: got %v expected %v", ttl.Round(time.Second), 1*time.Hour) + } - now := time.Now() - assert.WithinDuration(t, now.Add(ttl), now.Add(1*time.Hour), 1*time.Millisecond) - assert.True(t, ok) }) t.Run("Exists TTL", func(t *testing.T) { @@ -42,17 +47,18 @@ func TestStoreGetSet(t *testing.T) { want := []byte("Value") store.Set([]byte("Key"), want, time.Nanosecond) - _, _, ok := store.Get([]byte("Key")) - assert.False(t, ok) + if _, _, ok := store.Get([]byte("Key")); ok { + t.Errorf("expected key to not exist") + } }) t.Run("Not Exists", func(t *testing.T) { t.Parallel() store := setupTestStore(t) - - _, _, ok := store.Get([]byte("Key")) - assert.False(t, ok) + if _, _, ok := store.Get([]byte("Key")); ok { + t.Errorf("expected key to not exist") + } }) t.Run("Update", func(t *testing.T) { @@ -61,11 +67,16 @@ func TestStoreGetSet(t *testing.T) { store := setupTestStore(t) store.Set([]byte("Key"), []byte("Other"), 0) + want := []byte("Value") store.Set([]byte("Key"), want, 0) got, _, ok := store.Get([]byte("Key")) - assert.Equal(t, want, got) - assert.True(t, ok) + if !bytes.Equal(want, got) { + t.Errorf("got %v, want %v", got, want) + } + if !ok { + t.Errorf("expected key to exist") + } }) t.Run("Resize", func(t *testing.T) { @@ -80,16 +91,20 @@ func TestStoreGetSet(t *testing.T) { for i := range store.Length { key := binary.LittleEndian.AppendUint64(nil, i) - _, _, ok := store.Get(key) - assert.True(t, ok, i) + if _, _, ok := store.Get(key); !ok { + t.Errorf("expected key %v to exist", i) + } } - assert.Len(t, store.Bucket, int(initialBucketSize)*2) + if len(store.Bucket) != int(initialBucketSize)*2 { + t.Errorf("expected bucket size to be %v, got %v", initialBucketSize*2, len(store.Bucket)) + } for i := range store.Length { key := binary.LittleEndian.AppendUint64(nil, i) - _, _, ok := store.Get(key) - assert.True(t, ok, i) + if _, _, ok := store.Get(key); !ok { + t.Errorf("expected key %d to exist", i) + } } }) } @@ -104,10 +119,13 @@ func TestStoreDelete(t *testing.T) { want := []byte("Value") store.Set([]byte("Key"), want, 0) - ok := store.Delete([]byte("Key")) - assert.True(t, ok) - _, _, ok = store.Get([]byte("Key")) - assert.False(t, ok) + + if !store.Delete([]byte("Key")) { + t.Errorf("expected key to be deleted") + } + if _, _, ok := store.Get([]byte("Key")); ok { + t.Errorf("expected key to not exist") + } }) t.Run("Not Exists", func(t *testing.T) { @@ -115,8 +133,9 @@ func TestStoreDelete(t *testing.T) { store := setupTestStore(t) - ok := store.Delete([]byte("Key")) - assert.False(t, ok) + if store.Delete([]byte("Key")) { + t.Errorf("expected key to not exist") + } }) } @@ -128,25 +147,29 @@ func TestStoreClear(t *testing.T) { want := []byte("Value") store.Set([]byte("Key"), want, 0) store.Clear() - _, _, ok := store.Get([]byte("Key")) - assert.False(t, ok) + if _, _, ok := store.Get([]byte("Key")); ok { + t.Errorf("expected key to not exist") + } } func BenchmarkStoreGet(b *testing.B) { for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { want := setupTestStore(b) + for i := range n - 1 { buf := make([]byte, 8) binary.LittleEndian.PutUint64(buf, uint64(i)) want.Set(buf, buf, 0) } + key := []byte("Key") want.Set(key, []byte("Store"), 0) b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { want.Get(key) } }) @@ -155,18 +178,22 @@ func BenchmarkStoreGet(b *testing.B) { func BenchmarkStoreSet(b *testing.B) { for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { want := setupTestStore(b) + for i := range n - 1 { buf := make([]byte, 8) binary.LittleEndian.PutUint64(buf, uint64(i)) want.Set(buf, buf, 0) } + key := []byte("Key") store := []byte("Store") + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { want.Set(key, store, 0) } }) @@ -175,18 +202,22 @@ func BenchmarkStoreSet(b *testing.B) { func BenchmarkStoreDelete(b *testing.B) { for n := 1; n <= 10000; n *= 10 { - b.Run(fmt.Sprint(n), func(b *testing.B) { + b.Run(strconv.Itoa(n), func(b *testing.B) { want := setupTestStore(b) + for i := range n - 1 { buf := make([]byte, 8) binary.LittleEndian.PutUint64(buf, uint64(i)) want.Set(buf, buf, 0) } + key := []byte("Key") store := []byte("Store") + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + + for b.Loop() { want.Set(key, store, 0) want.Delete(key) } @@ -7,6 +7,7 @@ import ( // zero returns the zero value for the specified type. func zero[T any]() T { var ret T + return ret } @@ -16,5 +17,6 @@ func hash(data []byte) uint64 { if _, err := hasher.Write(data); err != nil { panic(err) } + return hasher.Sum64() } |