Files
base/pkg/cache/cache.go
2026-04-10 18:25:21 +03:30

130 lines
3.0 KiB
Go

package cache
import (
"context"
"fmt"
"time"
"github.com/samber/lo"
"base/pkg/store"
)
// Cache provides read-through caching for values of type V on top of a
// backing store, plus explicit invalidation.
type Cache[V any] interface {
	// WithCache returns the value cached under key. On a cache miss it calls
	// fn to load the value and caches the result for ttl.
	WithCache(ctx context.Context, key string, fn func(context.Context) (V, error), ttl time.Duration) (V, error)
	// WithHashCache returns the requested fields (keys) of the hash stored
	// under set. Fields missing from the cache are loaded by calling fn with
	// the missing field names and are written back with ttl. An empty keys
	// slice skips the cache read and calls fn with no field names.
	WithHashCache(ctx context.Context, set string, keys []string, fn func(context.Context, []string) (map[string]V, error), ttl time.Duration) (map[string]V, error)
	// InvalidateKeys deletes the given keys; calling it with no keys is a no-op.
	InvalidateKeys(ctx context.Context, keys ...string) error
	// InvalidatePattern hands pattern to the backing store for deletion; how
	// the pattern is matched is defined by the store implementation.
	InvalidatePattern(ctx context.Context, pattern string) error
}
// cache is the package's Cache implementation. It embeds store.Store[V], so
// the store's methods (Get, Set, HMGet, HMSet, Delete, DeleteMultiple, ...)
// are promoted onto cache and can be called as c.Get, c.HMSet, and so on.
type cache[V any] struct{ store.Store[V] }
// New returns a Cache that reads from and writes to s.
//
// s is stored as-is and is not checked for nil; a nil store makes every
// method of the returned Cache panic when it touches the store.
func New[V any](s store.Store[V]) Cache[V] {
	// The parameter is deliberately not named "store": that would shadow the
	// imported store package for the whole function body.
	return cache[V]{Store: s}
}
// WithCache implements read-through caching for a single key.
//
// If the store holds a value for key, that value is returned and fn is not
// called. Otherwise fn loads the value, which is then stored under key for
// ttl and returned.
//
// Errors are returned unwrapped:
//   - A failed cache read aborts immediately, without calling fn, and returns
//     whatever value Get produced.
//   - An fn error is returned together with whatever value fn produced.
//   - If storing the loaded value fails, the loaded value is still returned
//     alongside the Set error.
func (c cache[V]) WithCache(ctx context.Context, key string, fn func(context.Context) (V, error), ttl time.Duration) (V, error) {
	cached, hit, err := c.Get(ctx, key)
	if err != nil || hit {
		// err is nil on a clean hit, so this returns (cached, nil) in that case.
		return cached, err
	}

	loaded, err := fn(ctx)
	if err != nil {
		return loaded, err
	}
	return loaded, c.Set(ctx, key, loaded, ttl)
}
// WithHashCache returns the requested fields of the hash stored under set.
// It serves what it can from the cache and loads the rest with fn.
//
// When keys is non-empty, the fields are read from the cache first. If every
// field is found, fn is not called; otherwise fn is called with only the
// missing field names. When keys is empty, the cache read is skipped and fn
// is called with a nil slice (presumably meaning "load every field" — confirm
// with the fn implementations in callers).
//
// Values returned by fn are merged into the result and written back to the
// hash with HMSet using ttl. An empty fn result is not written back: there is
// nothing to store, and on Redis-backed stores an HSET/HMSET with no
// field-value pairs is rejected with an argument error. Writing it back would
// fail the whole call even though the cached hits are valid.
//
// Errors from the cache read, from fn, or from the cache write are returned
// unchanged together with a nil map.
func (c cache[V]) WithHashCache(ctx context.Context, set string, keys []string, fn func(context.Context, []string) (map[string]V, error), ttl time.Duration) (map[string]V, error) {
	var (
		result   map[string]V
		missKeys []string
		err      error
	)
	if len(keys) > 0 {
		result, missKeys, err = c.get(ctx, set, keys)
		if err != nil {
			return nil, err
		}
		if len(missKeys) == 0 {
			// Every requested field was already cached.
			return result, nil
		}
	}
	if result == nil {
		// Either the cache read was skipped (no keys) or it yielded no map;
		// make sure there is somewhere to merge the loaded values.
		result = make(map[string]V, len(keys))
	}

	loaded, err := fn(ctx, missKeys)
	if err != nil {
		return nil, err
	}
	if len(loaded) == 0 {
		// Nothing new to merge or cache; return the hits we already have.
		return result, nil
	}

	for k, v := range loaded {
		result[k] = v
	}
	if err = c.HMSet(ctx, set, loaded, ttl); err != nil {
		return nil, err
	}
	return result, nil
}
// get fetches the requested fields of the hash stored under setKey and
// reports which of them had no cached value.
//
// Results:
//   - On an HMGet error, that error is returned and both other results are
//     nil.
//   - If HMGet returned exactly as many entries as keys were requested, its
//     map is returned as-is with a nil miss list, and no per-key check is
//     done.
//   - Otherwise the miss list is the (possibly empty) subset of keys absent
//     from the map. It keeps request order and repeats a key that keys
//     repeats.
//
// An empty HMGet result is replaced with a fresh map, so callers can merge
// loaded values into it without hitting a nil-map write.
//
// NOTE(review): the length shortcut assumes HMGet only returns entries for
// fields that are actually cached; confirm against the store implementation.
func (c cache[V]) get(ctx context.Context, setKey string, keys []string) (map[string]V, []string, error) {
	hits, err := c.HMGet(ctx, setKey, keys...)
	switch {
	case err != nil:
		return nil, nil, err
	case len(hits) == len(keys):
		return hits, nil, nil
	case len(hits) == 0:
		hits = make(map[string]V, len(keys))
	}

	notCached := func(key string, _ int) bool {
		return !lo.HasKey(hits, key)
	}
	return hits, lo.Filter(keys, notCached), nil
}
// InvalidateKeys removes the given keys from the backing store with a single
// DeleteMultiple call. Calling it with no keys does nothing and returns nil.
// A store error is wrapped with %w, so callers can still match it with
// errors.Is / errors.As.
func (c cache[V]) InvalidateKeys(ctx context.Context, keys ...string) error {
	if len(keys) > 0 {
		if err := c.Store.DeleteMultiple(ctx, keys...); err != nil {
			return fmt.Errorf("failed to invalidate keys: %w", err)
		}
	}
	return nil
}
// InvalidatePattern forwards pattern to the backing store's Delete method and
// wraps any resulting error with %w.
//
// NOTE(review): whether Store.Delete treats its argument as a glob pattern or
// as a literal key is not visible here — confirm against the store
// implementation.
func (c cache[V]) InvalidatePattern(ctx context.Context, pattern string) error {
	err := c.Store.Delete(ctx, pattern)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to invalidate pattern: %w", err)
}