diff options
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | conn.go | 78 | ||||
-rw-r--r-- | conn_test.go | 28 | ||||
-rw-r--r-- | encoding.go | 8 | ||||
-rw-r--r-- | encoding_test.go | 140 | ||||
-rw-r--r-- | evict.go | 107 | ||||
-rw-r--r-- | store.go | 128 | ||||
-rw-r--r-- | store_test.go | 388 |
8 files changed, 640 insertions, 239 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fde9c9e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.out +vendor/ @@ -3,6 +3,7 @@ package cache import ( "errors" "io" + "log" "os" "sync" "time" @@ -58,7 +59,9 @@ func openFile(filename string, options ...Option) (*db, error) { func openMem(options ...Option) (*db, error) { ret := &db{} ret.Store.Init() - ret.SetConfig(options...) + if err := ret.SetConfig(options...); err != nil { + return nil, err + } return ret, nil } @@ -123,6 +126,12 @@ func SetCleanupTime(t time.Duration) Option { func (d *db) backgroundWorker() { defer d.wg.Done() + defer func() { + if r := recover(); r != nil { + log.Printf("Recovered from panic in background worker: %v", r) + } + }() + d.Store.SnapshotTicker.Resume() defer d.Store.SnapshotTicker.Stop() @@ -136,25 +145,30 @@ func (d *db) backgroundWorker() { case <-d.Store.SnapshotTicker.C: d.Flush() case <-d.Store.CleanupTicker.C: - cleanup(&d.Store) - evict(&d.Store) + d.Store.Cleanup() + d.Store.Evict() } } } // Close stops the background worker and cleans up resources. -func (d *db) Close() { +func (d *db) Close() error { close(d.Stop) d.wg.Wait() - d.Flush() + err := d.Flush() d.Clear() + var err1 error if d.File != nil { closer, ok := d.File.(io.Closer) if ok { - closer.Close() + err1 = closer.Close() } } + if err != nil { + return err + } + return err1 } // Flush writes the current state of the store to the file. @@ -273,3 +287,55 @@ func (h *DB[K, V]) Delete(key K) error { return nil } + +// UpdateInPlace retrieves a value from the cache, processes it using the provided function, +// and then sets the result back into the cache with the same key. +func (h *DB[K, V]) UpdateInPlace(key K, processFunc func(V) (V, error), ttl time.Duration) error { + keyData, err := marshal(key) + if err != nil { + return err + } + + return h.Store.UpdateInPlace(keyData, func(data []byte) ([]byte, error) { + var value V + if err := unmarshal(data, &value); err != nil { + return nil, err + } + + processedValue, err := processFunc(value) + if err != nil { + return nil, err + } + + return marshal(processedValue) + }, ttl) +} + +// Memoize attempts to retrieve a value from the cache. If the retrieval fails, +// it sets the result of the factory function into the cache and returns that result. +func (h *DB[K, V]) Memoize(key K, factoryFunc func() (V, error), ttl time.Duration) (V, error) { + keyData, err := marshal(key) + if err != nil { + return zero[V](), err + } + + data, err := h.Store.Memoize(keyData, func() ([]byte, error) { + value, err := factoryFunc() + if err != nil { + return nil, err + } + + return marshal(value) + }, ttl) + + if err != nil { + return zero[V](), err + } + + var value V + if err := unmarshal(data, &value); err != nil { + return zero[V](), err + } + + return value, nil +} diff --git a/conn_test.go b/conn_test.go index a97a5e9..ee2c175 100644 --- a/conn_test.go +++ b/conn_test.go @@ -136,7 +136,9 @@ func BenchmarkDBGet(b *testing.B) { b.Run(strconv.Itoa(n), func(b *testing.B) { db := setupTestDB[int, int](b) for i := range n { - db.Set(i, i, 0) + if err := db.Set(i, i, 0); err != nil { + b.Fatalf("unexpected error: %v", err) + } } b.ReportAllocs() @@ -144,7 +146,9 @@ func BenchmarkDBGet(b *testing.B) { b.ResetTimer() for b.Loop() { - db.GetValue(n - 1) + if _, _, err := db.GetValue(n - 1); err != nil { + b.Fatalf("unexpected error: %v", err) + } } }) } @@ -155,14 +159,18 @@ func BenchmarkDBSet(b *testing.B) { b.Run(strconv.Itoa(n), func(b *testing.B) { db := setupTestDB[int, int](b) for i := range n - 1 { - db.Set(i, i, 0) + if err := db.Set(i, i, 0); err != nil { + b.Fatalf("unexpected error: %v", err) + } } b.ReportAllocs() b.ResetTimer() for b.Loop() { - db.Set(n, n, 0) + if err := db.Set(n, n, 0); err != nil { + b.Fatalf("unexpected error: %v", err) + } } }) } @@ -173,15 +181,21 @@ func BenchmarkDBDelete(b *testing.B) { b.Run(strconv.Itoa(n), func(b *testing.B) { db := setupTestDB[int, int](b) for i := range n - 1 { - db.Set(i, i, 0) + if err := db.Set(i, i, 0); err != nil { + b.Fatalf("unexpected error: %v", err) + } } b.ReportAllocs() b.ResetTimer() for b.Loop() { - db.Set(n, n, 0) - db.Delete(n) + if err := db.Set(n, n, 0); err != nil { + b.Fatalf("unexpected error: %v", err) + } + if err := db.Delete(n); err != nil { + b.Fatalf("unexpected error: %v", err) + } } }) } diff --git a/encoding.go b/encoding.go index 1327087..a4af699 100644 --- a/encoding.go +++ b/encoding.go @@ -81,7 +81,7 @@ func (e *encoder) EncodeStore(s *store) error { return err } - for v := s.Evict.EvictNext; v != &s.Evict; v = v.EvictNext { + for v := s.EvictList.EvictNext; v != &s.EvictList; v = v.EvictNext { if err := e.EncodeNode(v); err != nil { return err } @@ -186,7 +186,9 @@ func (d *decoder) DecodeStore(s *store) error { return err } - s.Policy.SetPolicy(EvictionPolicyType(policy)) + if err := s.Policy.SetPolicy(EvictionPolicyType(policy)); err != nil { + return err + } length, err := d.DecodeUint64() if err != nil { @@ -217,7 +219,7 @@ func (d *decoder) DecodeStore(s *store) error { v.HashNext.HashPrev = v v.HashPrev.HashNext = v - v.EvictNext = &s.Evict + v.EvictNext = &s.EvictList v.EvictPrev = v.EvictNext.EvictPrev v.EvictNext.EvictPrev = v v.EvictPrev.EvictNext = v diff --git a/encoding_test.go b/encoding_test.go index 4f3af4f..0f9245e 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -3,7 +3,6 @@ package cache import ( "bytes" "encoding/binary" - "fmt" "os" "strconv" "testing" @@ -36,91 +35,17 @@ func TestEncodeDecodeUint64(t *testing.T) { var buf bytes.Buffer e := newEncoder(&buf) - err := e.EncodeUint64(tt.value) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - err = e.Flush() - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - decoder := newDecoder(bytes.NewReader(buf.Bytes())) - - decodedValue, err := decoder.DecodeUint64() - if err != nil { + if err := e.EncodeUint64(tt.value); err != nil { t.Errorf("unexpected error: %v", err) } - if tt.value != decodedValue { - t.Errorf("expected %v, got %v", tt.value, decodedValue) - } - }) - } -} - -func TestEncodeDecodeStoreWithPolicies(t *testing.T) { - policies := []EvictionPolicyType{PolicyFIFO, PolicyLRU, PolicyLFU, PolicyLTR} - - for _, policy := range policies { - t.Run(fmt.Sprintf("Policy_%d", policy), func(t *testing.T) { - var buf bytes.Buffer - e := newEncoder(&buf) - - store := setupTestStore(t) - store.Policy.SetPolicy(policy) - - err := e.EncodeStore(store) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - err = e.Flush() - if err != nil { + if err := e.Flush(); err != nil { t.Errorf("unexpected error: %v", err) } decoder := newDecoder(bytes.NewReader(buf.Bytes())) - decodedStore := setupTestStore(t) - - err = decoder.DecodeStore(decodedStore) - if err != nil { - t.Errorf("unexpected error: %v", err) - } - - if store.Policy.Type != decodedStore.Policy.Type { - t.Errorf("expected %v, got %v", store.Policy.Type, decodedStore.Policy.Type) - } - }) - } -} -func TestEncodeDecodeTimeBoundary(t *testing.T) { - tests := []struct { - name string - value time.Time - }{ - {name: "Unix Epoch", value: time.Unix(0, 0)}, - {name: "Far Future", value: time.Unix(1<<63-1, 0)}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - e := newEncoder(&buf) - - err := e.EncodeTime(tt.value) - - if err != nil { - t.Errorf("unexpected error: %v", err) - } - if err != nil { - t.Errorf("unexpected error: %v", err) - } - err = e.Flush() - - decoder := newDecoder(bytes.NewReader(buf.Bytes())) - - decodedValue, err := decoder.DecodeTime() + decodedValue, err := decoder.DecodeUint64() if err != nil { t.Errorf("unexpected error: %v", err) } @@ -147,12 +72,10 @@ func TestEncodeDecodeTime(t *testing.T) { var buf bytes.Buffer e := newEncoder(&buf) - err := e.EncodeTime(tt.value) - if err != nil { + if err := e.EncodeTime(tt.value); err != nil { t.Fatalf("unexpected error: %v", err) } - err = e.Flush() - if err != nil { + if err := e.Flush(); err != nil { t.Fatalf("unexpected error: %v", err) } @@ -175,19 +98,16 @@ func TestDecodeBytesError(t *testing.T) { var buf bytes.Buffer e := newEncoder(&buf) - err := e.EncodeBytes([]byte("DEADBEEF")) - if err != nil { + if err := e.EncodeBytes([]byte("DEADBEEF")); err != nil { t.Errorf("unexpected error: %v", err) } - err = e.Flush() - if err != nil { + if err := e.Flush(); err != nil { t.Errorf("unexpected error: %v", err) } decoder := newDecoder(bytes.NewReader(buf.Bytes()[:10])) - _, err = decoder.DecodeBytes() - if err == nil { + if _, err := decoder.DecodeBytes(); err == nil { t.Errorf("expected an error but got none") } } @@ -207,12 +127,10 @@ func TestEncodeDecodeBytes(t *testing.T) { var buf bytes.Buffer e := newEncoder(&buf) - err := e.EncodeBytes(tt.value) - if err != nil { + if err := e.EncodeBytes(tt.value); err != nil { t.Errorf("unexpected error: %v", err) } - err = e.Flush() - if err != nil { + if err := e.Flush(); err != nil { t.Errorf("unexpected error: %v", err) } @@ -272,12 +190,11 @@ func TestEncodeDecodeNode(t *testing.T) { var buf bytes.Buffer e := newEncoder(&buf) - err := e.EncodeNode(tt.value) - if err != nil { + if err := e.EncodeNode(tt.value); err != nil { t.Errorf("unexpected error: %v", err) } - err = e.Flush() - if err != nil { + + if err := e.Flush(); err != nil { t.Errorf("unexpected error: %v", err) } @@ -348,8 +265,8 @@ func TestEncodeDecodeStrorage(t *testing.T) { want := setupTestStore(t) want.MaxCost = uint64(tt.maxCost) - err := want.Policy.SetPolicy(tt.policy) - if err != nil { + + if err := want.Policy.SetPolicy(tt.policy); err != nil { t.Errorf("unexpected error: %v", err) } @@ -357,20 +274,17 @@ func TestEncodeDecodeStrorage(t *testing.T) { want.Set([]byte(k), []byte(v), 0) } - err = e.EncodeStore(want) - if err != nil { + if err := e.EncodeStore(want); err != nil { t.Errorf("unexpected error: %v", err) } - err = e.Flush() - if err != nil { + if err := e.Flush(); err != nil { t.Errorf("unexpected error: %v", err) } decoder := newDecoder(bytes.NewReader(buf.Bytes())) got := setupTestStore(t) - err = decoder.DecodeStore(got) - if err != nil { + if err := decoder.DecodeStore(got); err != nil { t.Errorf("unexpected error: %v", err) } @@ -384,8 +298,8 @@ func TestEncodeDecodeStrorage(t *testing.T) { t.Errorf("expected %v, got %v", want.Policy.Type, got.Policy.Type) } - gotOrder := getListOrder(t, &got.Evict) - for i, v := range getListOrder(t, &want.Evict) { + gotOrder := getListOrder(t, &got.EvictList) + for i, v := range getListOrder(t, &want.EvictList) { if !bytes.Equal(v.Key, gotOrder[i].Key) { t.Errorf("expected %#v, got %#v", v.Key, gotOrder[i].Key) } @@ -423,8 +337,7 @@ func BenchmarkEncoder_EncodeStore(b *testing.B) { want.Set(buf, buf, 0) } - err = want.Snapshot(file) - if err != nil { + if err = want.Snapshot(file); err != nil { b.Fatalf("unexpected error: %v", err) } @@ -438,7 +351,9 @@ func BenchmarkEncoder_EncodeStore(b *testing.B) { b.ResetTimer() for b.Loop() { - want.Snapshot(file) + if err := want.Snapshot(file); err != nil { + b.Fatalf("unexpected error: %v", err) + } } }) } @@ -464,8 +379,7 @@ func BenchmarkDecoder_DecodeStore(b *testing.B) { want.Set(buf, buf, 0) } - err = want.Snapshot(file) - if err != nil { + if err = want.Snapshot(file); err != nil { b.Fatalf("unexpected error: %v", err) } fileInfo, err := file.Stat() @@ -478,7 +392,9 @@ func BenchmarkDecoder_DecodeStore(b *testing.B) { b.ResetTimer() for b.Loop() { - want.LoadSnapshot(file) + if err := want.LoadSnapshot(file); err != nil { + b.Fatalf("unexpected error: %v", err) + } } }) } @@ -80,16 +80,18 @@ type fifoPolicy struct { } // OnInsert adds a node to the eviction list. -func (s fifoPolicy) OnInsert(node *node) { - pushEvict(node, s.evict) +func (s fifoPolicy) OnInsert(n *node) { + pushEvict(n, s.evict) } // OnAccess is a no-op for fifoPolicy. func (fifoPolicy) OnAccess(n *node) { + // Noop } // OnUpdate is a no-op for fifoPolicy. func (fifoPolicy) OnUpdate(n *node) { + // Noop } // Evict returns the oldest node for fifoPolicy. @@ -111,20 +113,20 @@ type lruPolicy struct { } // OnInsert adds a node to the eviction list. -func (s lruPolicy) OnInsert(node *node) { - pushEvict(node, s.evict) +func (s lruPolicy) OnInsert(n *node) { + pushEvict(n, s.evict) } // OnUpdate moves the accessed node to the front of the eviction list. -func (s lruPolicy) OnUpdate(node *node) { - s.OnAccess(node) +func (s lruPolicy) OnUpdate(n *node) { + s.OnAccess(n) } // OnAccess moves the accessed node to the front of the eviction list. -func (s lruPolicy) OnAccess(node *node) { - node.EvictNext.EvictPrev = node.EvictPrev - node.EvictPrev.EvictNext = node.EvictNext - s.OnInsert(node) +func (s lruPolicy) OnAccess(n *node) { + n.EvictNext.EvictPrev = n.EvictPrev + n.EvictPrev.EvictNext = n.EvictNext + s.OnInsert(n) } // Evict returns the least recently used node for lruPolicy. @@ -146,40 +148,40 @@ type lfuPolicy struct { } // OnInsert adds a node to the eviction list and initializes its access count. -func (s lfuPolicy) OnInsert(node *node) { - pushEvict(node, s.evict) +func (s lfuPolicy) OnInsert(n *node) { + pushEvict(n, s.evict) } // OnUpdate increments the access count of the node and reorders the list. -func (s lfuPolicy) OnUpdate(node *node) { - s.OnAccess(node) +func (s lfuPolicy) OnUpdate(n *node) { + s.OnAccess(n) } // OnAccess increments the access count of the node and reorders the list. -func (s lfuPolicy) OnAccess(node *node) { - node.Access++ +func (s lfuPolicy) OnAccess(n *node) { + n.Access++ - for v := node.EvictPrev; v.EvictPrev != s.evict; v = v.EvictPrev { - if v.Access <= node.Access { - node.EvictNext.EvictPrev = node.EvictPrev - node.EvictPrev.EvictNext = node.EvictNext + for v := n.EvictPrev; v.EvictPrev != s.evict; v = v.EvictPrev { + if v.Access <= n.Access { + n.EvictNext.EvictPrev = n.EvictPrev + n.EvictPrev.EvictNext = n.EvictNext - node.EvictPrev = v - node.EvictNext = node.EvictPrev.EvictNext - node.EvictNext.EvictPrev = node - node.EvictPrev.EvictNext = node + n.EvictPrev = v + n.EvictNext = n.EvictPrev.EvictNext + n.EvictNext.EvictPrev = n + n.EvictPrev.EvictNext = n return } } - node.EvictNext.EvictPrev = node.EvictPrev - node.EvictPrev.EvictNext = node.EvictNext + n.EvictNext.EvictPrev = n.EvictPrev + n.EvictPrev.EvictNext = n.EvictNext - node.EvictPrev = s.evict - node.EvictNext = node.EvictPrev.EvictNext - node.EvictNext.EvictPrev = node - node.EvictPrev.EvictNext = node + n.EvictPrev = s.evict + n.EvictNext = n.EvictPrev.EvictNext + n.EvictNext.EvictPrev = n + n.EvictPrev.EvictNext = n } // Evict returns the least frequently used node for LFU. @@ -203,55 +205,56 @@ type ltrPolicy struct { // OnInsert adds a node to the eviction list based on its TTL (Time To Live). // It places the node in the correct position in the list based on TTL. -func (s ltrPolicy) OnInsert(node *node) { - pushEvict(node, s.evict) +func (s ltrPolicy) OnInsert(n *node) { + pushEvict(n, s.evict) - s.OnUpdate(node) + s.OnUpdate(n) } // OnAccess is a no-op for ltrPolicy. // It does not perform any action when a node is accessed. -func (s ltrPolicy) OnAccess(node *node) { +func (s ltrPolicy) OnAccess(n *node) { + // Noop } // OnUpdate updates the position of the node in the eviction list based on its TTL. // It reorders the list to maintain the correct order based on TTL. -func (s ltrPolicy) OnUpdate(node *node) { - if node.TTL() == 0 { +func (s ltrPolicy) OnUpdate(n *node) { + if n.TTL() == 0 { return } - for v := node.EvictPrev; v.EvictPrev != s.evict; v = v.EvictPrev { + for v := n.EvictPrev; v.EvictPrev != s.evict; v = v.EvictPrev { if v.TTL() == 0 { continue } - if v.TTL() < node.TTL() { - node.EvictNext.EvictPrev = node.EvictPrev - node.EvictPrev.EvictNext = node.EvictNext + if v.TTL() < n.TTL() { + n.EvictNext.EvictPrev = n.EvictPrev + n.EvictPrev.EvictNext = n.EvictNext - node.EvictPrev = v - node.EvictNext = node.EvictPrev.EvictNext - node.EvictNext.EvictPrev = node - node.EvictPrev.EvictNext = node + n.EvictPrev = v + n.EvictNext = n.EvictPrev.EvictNext + n.EvictNext.EvictPrev = n + n.EvictPrev.EvictNext = n return } } - for v := node.EvictNext; v.EvictNext != s.evict; v = v.EvictNext { + for v := n.EvictNext; v.EvictNext != s.evict; v = v.EvictNext { if v.TTL() == 0 { continue } - if v.TTL() > node.TTL() { - node.EvictNext.EvictPrev = node.EvictPrev - node.EvictPrev.EvictNext = node.EvictNext + if v.TTL() > n.TTL() { + n.EvictNext.EvictPrev = n.EvictPrev + n.EvictPrev.EvictNext = n.EvictNext - node.EvictPrev = v - node.EvictNext = node.EvictPrev.EvictNext - node.EvictNext.EvictPrev = node - node.EvictPrev.EvictNext = node + n.EvictPrev = v + n.EvictNext = n.EvictPrev.EvictNext + n.EvictNext.EvictPrev = n + n.EvictPrev.EvictNext = n return } @@ -77,8 +77,8 @@ func (s *store) Clear() { s.EvictList.EvictPrev = &s.EvictList } -// lookup calculates the hash and index for a given key. -func lookup(s *store, key []byte) (uint64, uint64) { +// lookupIdx calculates the hash and index for a given key. +func lookupIdx(s *store, key []byte) (uint64, uint64) { hash := hash(key) return hash % uint64(len(s.Bucket)), hash @@ -94,7 +94,7 @@ func lazyInitBucket(n *node) { // lookup finds a node in the store by key. func (s *store) lookup(key []byte) (*node, uint64, uint64) { - idx, hash := lookup(s, key) + idx, hash := lookupIdx(s, key) bucket := &s.Bucket[idx] @@ -154,10 +154,12 @@ func (s *store) Cleanup() { s.mu.Lock() defer s.mu.Unlock() - for v := s.EvictList.EvictNext; v != &s.EvictList; v = v.EvictNext { + for v := s.EvictList.EvictNext; v != &s.EvictList; { + n := v.EvictNext if !v.IsValid() { deleteNode(s, v) } + v = n } } @@ -166,62 +168,78 @@ func (s *store) Evict() bool { s.mu.Lock() defer s.mu.Unlock() - for s.MaxCost != 0 && s.MaxCost < s.Cost { + if s.MaxCost == 0 { + return true + } + + for s.MaxCost < s.Cost { n := s.Policy.Evict() if n == nil { break } - deleteNode(s, n) } return true } -// Set adds or updates a key-value pair in the store with locking. -func (s *store) Set(key []byte, value []byte, ttl time.Duration) { - s.mu.Lock() - defer s.mu.Unlock() - - v, idx, hash := s.lookup(key) - if v != nil { - s.Cost = s.Cost + uint64(len(value)) - uint64(len(v.Value)) - v.Value = value - v.Expiration = time.Now().Add(ttl) - s.Policy.OnUpdate(v) - } - +// insert adds a new key-value pair to the store. +func (s *store) insert(key []byte, value []byte, ttl time.Duration) { + idx, hash := lookupIdx(s, key) bucket := &s.Bucket[idx] if float64(s.Length)/float64(len(s.Bucket)) > 0.75 { s.Resize() - // resize may invidate pointer to bucket - _, idx, _ := s.lookup(key) + // resize may invalidate pointer to bucket + _, idx, _ = s.lookup(key) bucket = &s.Bucket[idx] lazyInitBucket(bucket) } - node := &node{ + v := &node{ Hash: hash, Key: key, Value: value, } if ttl != 0 { - node.Expiration = time.Now().Add(ttl) + v.Expiration = time.Now().Add(ttl) + } else { + v.Expiration = zero[time.Time]() } - node.HashPrev = bucket - node.HashNext = node.HashPrev.HashNext - node.HashNext.HashPrev = node - node.HashPrev.HashNext = node + v.HashPrev = bucket + v.HashNext = v.HashPrev.HashNext + v.HashNext.HashPrev = v + v.HashPrev.HashNext = v - s.Policy.OnInsert(node) + s.Policy.OnInsert(v) s.Cost = s.Cost + uint64(len(key)) + uint64(len(value)) s.Length = s.Length + 1 } +// Set adds or updates a key-value pair in the store with locking. +func (s *store) Set(key []byte, value []byte, ttl time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + + v, _, _ := s.lookup(key) + if v != nil { + s.Cost = s.Cost + uint64(len(value)) - uint64(len(v.Value)) + v.Value = value + if ttl != 0 { + v.Expiration = time.Now().Add(ttl) + } else { + v.Expiration = zero[time.Time]() + } + s.Policy.OnUpdate(v) + return + } + + s.insert(key, value, ttl) +} + // deleteNode removes a node from the store. func deleteNode(s *store, v *node) { v.HashNext.HashPrev = v.HashPrev @@ -252,3 +270,57 @@ func (s *store) Delete(key []byte) bool { return false } + +// UpdateInPlace retrieves a value from the store, processes it using the provided function, +// and then sets the result back into the store with the same key. +func (s *store) UpdateInPlace(key []byte, processFunc func([]byte) ([]byte, error), ttl time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + + v, _, _ := s.lookup(key) + if v == nil { + return ErrKeyNotFound + } + + if !v.IsValid() { + deleteNode(s, v) + return ErrKeyNotFound + } + + processedValue, err := processFunc(v.Value) + if err != nil { + return err + } + + s.Cost = s.Cost + uint64(len(processedValue)) - uint64(len(v.Value)) + v.Value = processedValue + if ttl != 0 { + v.Expiration = time.Now().Add(ttl) + } else { + v.Expiration = zero[time.Time]() + } + s.Policy.OnUpdate(v) + + return nil +} + +// Memorize attempts to retrieve a value from the store. If the retrieval fails, +// it sets the result of the factory function into the store and returns that result. +func (s *store) Memorize(key []byte, factoryFunc func() ([]byte, error), ttl time.Duration) ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + + v, _, _ := s.lookup(key) + if v != nil && v.IsValid() { + s.Policy.OnAccess(v) + return v.Value, nil + } + + factoryValue, err := factoryFunc() + if err != nil { + return nil, err + } + + s.insert(key, factoryValue, ttl) + return factoryValue, nil +} diff --git a/store_test.go b/store_test.go index 89f4de6..e2c44e9 100644 --- a/store_test.go +++ b/store_test.go @@ -3,6 +3,7 @@ package cache import ( "bytes" "encoding/binary" + "errors" "strconv" "testing" "time" @@ -26,10 +27,30 @@ func TestStoreGetSet(t *testing.T) { store := setupTestStore(t) want := []byte("Value") + store.Set([]byte("Key"), want, 0) + got, ttl, ok := store.Get([]byte("Key")) + if !ok { + t.Fatalf("expected key to exist") + } + if !bytes.Equal(want, got) { + t.Errorf("got %v, want %v", got, want) + } + if ttl.Round(time.Second) != 0 { + t.Errorf("ttl same: got %v expected %v", ttl.Round(time.Second), 1*time.Hour) + } + + }) + + t.Run("Exists Non Expiry", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + want := []byte("Value") store.Set([]byte("Key"), want, 1*time.Hour) got, ttl, ok := store.Get([]byte("Key")) if !ok { - t.Errorf("expected key to exist") + t.Fatalf("expected key to exist") } if !bytes.Equal(want, got) { t.Errorf("got %v, want %v", got, want) @@ -71,12 +92,13 @@ func TestStoreGetSet(t *testing.T) { want := []byte("Value") store.Set([]byte("Key"), want, 0) got, _, ok := store.Get([]byte("Key")) + if !ok { + t.Fatal("expected key to exist") + } if !bytes.Equal(want, got) { t.Errorf("got %v, want %v", got, want) } - if !ok { - t.Errorf("expected key to exist") - } + }) t.Run("Resize", func(t *testing.T) { @@ -152,49 +174,353 @@ func TestStoreClear(t *testing.T) { } } +func TestStoreUpdateInPlace(t *testing.T) { + t.Parallel() + + t.Run("Exists", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + want := []byte("Value") + store.Set([]byte("Key"), []byte("Initial"), 1*time.Hour) + + processFunc := func(v []byte) ([]byte, error) { + return want, nil + } + + if err := store.UpdateInPlace([]byte("Key"), processFunc, 1*time.Hour); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + got, _, ok := store.Get([]byte("Key")) + if !ok { + t.Fatalf("expected key to exist") + } + if !bytes.Equal(want, got) { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("Not Exists", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + processFunc := func(v []byte) ([]byte, error) { + return []byte("Value"), nil + } + + if err := store.UpdateInPlace([]byte("Key"), processFunc, 1*time.Hour); !errors.Is(err, ErrKeyNotFound) { + t.Fatalf("expected error: %v, got: %v", ErrKeyNotFound, err) + } + }) +} + +func TestStoreMemoize(t *testing.T) { + t.Parallel() + + t.Run("Cache Miss", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + factoryFunc := func() ([]byte, error) { + return []byte("Value"), nil + } + + got, err := store.Memoize([]byte("Key"), factoryFunc, 1*time.Hour) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(got, []byte("Value")) { + t.Fatalf("expected: %v, got: %v", "Value", got) + } + + got, _, ok := store.Get([]byte("Key")) + if !ok { + t.Fatalf("expected key to exist") + } + if !bytes.Equal(got, []byte("Value")) { + t.Fatalf("expected: %v, got: %v", "Value", got) + } + }) + + t.Run("Cache Hit", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + store.Set([]byte("Key"), []byte("Value"), 1*time.Hour) + + factoryFunc := func() ([]byte, error) { + return []byte("NewValue"), nil + } + + got, err := store.Memoize([]byte("Key"), factoryFunc, 1*time.Hour) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(got, []byte("Value")) { + t.Fatalf("expected: %v, got: %v", "Value", got) + } + }) +} + +func TestStoreCleanup(t *testing.T) { + t.Parallel() + + t.Run("Cleanup Expired", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + store.Set([]byte("1"), []byte("1"), 500*time.Millisecond) + store.Set([]byte("2"), []byte("2"), 1*time.Hour) + + time.Sleep(600 * time.Millisecond) + + store.Cleanup() + + if _, _, ok := store.Get([]byte("1")); ok { + t.Fatalf("expected 1 to not exist") + } + + if _, _, ok := store.Get([]byte("2")); !ok { + t.Fatalf("expected 2 to exist") + } + }) + + t.Run("No Cleanup", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + + store.Set([]byte("Key"), []byte("Value"), 1*time.Hour) + + // No cleanup should occur + store.Cleanup() + + if _, _, ok := store.Get([]byte("Key")); !ok { + t.Fatalf("expected key to exist") + } + }) +} + +func TestStoreEvict(t *testing.T) { + t.Parallel() + + t.Run("Evict FIFO", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + if err := store.Policy.SetPolicy(PolicyFIFO); err != nil { + t.Fatalf("unexpected error: %v", err) + } + store.MaxCost = 5 + + store.Set([]byte("1"), []byte("1"), 0) + store.Set([]byte("2"), []byte("2"), 0) + + // Trigger eviction + store.Set([]byte("3"), []byte("3"), 0) + store.Evict() + + if _, _, ok := store.Get([]byte("1")); ok { + t.Fatalf("expected key 1 to not exist") + } + + if _, _, ok := store.Get([]byte("2")); !ok { + t.Fatalf("expected key 2 to exist") + } + }) + + t.Run("No Evict", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + if err := store.Policy.SetPolicy(PolicyFIFO); err != nil { + t.Fatalf("unexpected error: %v", err) + } + store.MaxCost = 10 + + store.Set([]byte("1"), []byte("1"), 0) + store.Set([]byte("2"), []byte("2"), 0) + + // No eviction should occur + store.Set([]byte("3"), []byte("3"), 0) + store.Evict() + + if _, _, ok := store.Get([]byte("1")); !ok { + t.Fatalf("expected key 1 to exist") + } + + if _, _, ok := store.Get([]byte("2")); !ok { + t.Fatalf("expected key 2 to exist") + } + }) + + t.Run("No Evict PolicyNone", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + if err := store.Policy.SetPolicy(PolicyNone); err != nil { + t.Fatalf("unexpected error: %v", err) + } + store.MaxCost = 5 + + store.Set([]byte("1"), []byte("1"), 0) + store.Set([]byte("2"), []byte("2"), 0) + + // No eviction should occur + store.Set([]byte("3"), []byte("3"), 0) + store.Evict() + + if _, _, ok := store.Get([]byte("1")); !ok { + t.Fatalf("expected key 1 to exist") + } + + if _, _, ok := store.Get([]byte("2")); !ok { + t.Fatalf("expected key 2 to exist") + } + }) + + t.Run("No Evict MaxCost Zero", func(t *testing.T) { + t.Parallel() + + store := setupTestStore(t) + if err := store.Policy.SetPolicy(PolicyFIFO); err != nil { + t.Fatalf("unexpected error: %v", err) + } + store.MaxCost = 0 + + store.Set([]byte("1"), []byte("1"), 0) + store.Set([]byte("2"), []byte("2"), 0) + + store.Evict() + + if _, _, ok := store.Get([]byte("1")); !ok { + t.Fatalf("expected key 1 to exist") + } + + if _, _, ok := store.Get([]byte("2")); !ok { + t.Fatalf("expected key 2 to exist") + } + }) +} + func BenchmarkStoreGet(b *testing.B) { - for n := 1; n <= 10000; n *= 10 { - b.Run(strconv.Itoa(n), func(b *testing.B) { - want := setupTestStore(b) + policy := map[string]EvictionPolicyType{ + "None": PolicyNone, + "FIFO": PolicyFIFO, + "LRU": PolicyLRU, + "LFU": PolicyLFU, + "LTR": PolicyLTR, + } + for k, v := range policy { + b.Run(k, func(b *testing.B) { + for n := 1; n <= 10000; n *= 10 { + b.Run(strconv.Itoa(n), func(b *testing.B) { + want := setupTestStore(b) - for i := range n - 1 { - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, uint64(i)) - want.Set(buf, buf, 0) - } + if err := want.Policy.SetPolicy(v); err != nil { + b.Fatalf("unexpected error: %v", err) + } - key := []byte("Key") - want.Set(key, []byte("Store"), 0) - b.ReportAllocs() + for i := range n - 1 { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(i)) + want.Set(buf, buf, 0) + } - b.ResetTimer() + key := []byte("Key") + want.Set(key, []byte("Store"), 0) + b.ReportAllocs() - for b.Loop() { - want.Get(key) + b.ResetTimer() + + for b.Loop() { + want.Get(key) + } + }) } }) } } func BenchmarkStoreSet(b *testing.B) { - for n := 1; n <= 10000; n *= 10 { - b.Run(strconv.Itoa(n), func(b *testing.B) { - want := setupTestStore(b) + policy := map[string]EvictionPolicyType{ + "None": PolicyNone, + "FIFO": PolicyFIFO, + "LRU": PolicyLRU, + "LFU": PolicyLFU, + "LTR": PolicyLTR, + } + for k, v := range policy { + b.Run(k, func(b *testing.B) { + for n := 1; n <= 10000; n *= 10 { + b.Run(strconv.Itoa(n), func(b *testing.B) { + want := setupTestStore(b) - for i := range n - 1 { - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, uint64(i)) - want.Set(buf, buf, 0) + if err := want.Policy.SetPolicy(v); err != nil { + b.Fatalf("unexpected error: %v", err) + } + + for i := range n - 1 { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(i)) + want.Set(buf, buf, 0) + } + + key := []byte("Key") + store := []byte("Store") + + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + want.Set(key, store, 0) + } + }) } + }) + } +} - key := []byte("Key") - store := []byte("Store") +func BenchmarkStoreSetInsert(b *testing.B) { + policy := map[string]EvictionPolicyType{ + "None": PolicyNone, + "FIFO": PolicyFIFO, + "LRU": PolicyLRU, + "LFU": PolicyLFU, + "LTR": PolicyLTR, + } + for k, v := range policy { + b.Run(k, func(b *testing.B) { + for n := 1; n <= 10000; n *= 10 { + b.Run(strconv.Itoa(n), func(b *testing.B) { + want := setupTestStore(b) - b.ReportAllocs() - b.ResetTimer() + if err := want.Policy.SetPolicy(v); err != nil { + b.Fatalf("unexpected error: %v", err) + } - for b.Loop() { - want.Set(key, store, 0) + list := make([][]byte, n) + for i := range n { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(i)) + list = append(list, buf) + } + + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + for _, k := range list { + want.Set(k, k, 0) + } + } + }) } }) } |