// Package cache provides a read-through cache layered on top of store.Store.
package cache

import (
	"context"
	"fmt"
	"time"

	"github.com/samber/lo"
	"base/pkg/store"
)

// Cache is a read-through cache for values of type V.
//
// On a miss, the caller-supplied loader is invoked and its result is written
// back to the underlying store before being returned.
type Cache[V any] interface {
	// WithCache returns the value stored under key. On a miss it calls fn,
	// writes the result under key (passing ttl to the store), and returns it.
	WithCache(ctx context.Context, key string, fn func(context.Context) (V, error), ttl time.Duration) (V, error)

	// WithHashCache returns the values for keys from the hash named set.
	// Keys not returned by the hash lookup are loaded with fn, written back to
	// the hash (passing ttl to the store), and merged into the result.
	// When keys is empty the cache lookup is skipped: fn is called with a nil
	// slice and whatever it returns is written back and returned.
	WithHashCache(ctx context.Context, set string, keys []string, fn func(context.Context, []string) (map[string]V, error), ttl time.Duration) (map[string]V, error)

	// InvalidateKeys deletes the given keys from the store. It is a no-op when
	// no keys are given.
	InvalidateKeys(ctx context.Context, keys ...string) error

	// InvalidatePattern asks the store to delete entries matching pattern.
	InvalidatePattern(ctx context.Context, pattern string) error
}

// cache implements Cache by embedding a store.Store, whose Get, Set, HMGet,
// HMSet, Delete and DeleteMultiple methods it calls directly.
type cache[V any] struct {
	store.Store[V]
}

// New returns a Cache backed by s.
func New[V any](s store.Store[V]) Cache[V] {
	return cache[V]{s}
}

// WithCache implements Cache.WithCache.
//
// Errors from the lookup, from fn, and from the write-back are returned
// unwrapped. A failed write-back is reported as an error even though fn
// succeeded; the loaded value is still returned alongside that error.
func (c cache[V]) WithCache(ctx context.Context, key string, fn func(context.Context) (V, error), ttl time.Duration) (V, error) {
	result, found, err := c.Get(ctx, key)
	if err != nil {
		return result, err
	}
	if found {
		return result, nil
	}

	result, err = fn(ctx)
	if err != nil {
		return result, err
	}

	if err = c.Set(ctx, key, result, ttl); err != nil {
		return result, err
	}
	return result, nil
}

// WithHashCache implements Cache.WithHashCache.
//
// Step 1 reads the requested keys from the hash and works out which are
// missing. Step 2 loads the missing keys (or everything, when keys is empty)
// from fn. Step 3 merges and writes back the loaded values. Any error aborts
// the call and returns a nil map.
func (c cache[V]) WithHashCache(ctx context.Context, set string, keys []string, fn func(context.Context, []string) (map[string]V, error), ttl time.Duration) (map[string]V, error) {
	var (
		result   map[string]V
		missKeys []string
	)

	if len(keys) == 0 {
		// No keys requested: skip the cache and load everything from the
		// source. fn receives a nil slice in this case.
		result = make(map[string]V)
	} else {
		var getErr error
		result, missKeys, getErr = c.get(ctx, set, keys)
		if getErr != nil {
			return nil, getErr
		}
		if len(missKeys) == 0 {
			// Every requested key was found in the hash.
			return result, nil
		}
	}

	loaded, err := fn(ctx, missKeys)
	if err != nil {
		return nil, err
	}

	// Merge freshly loaded values over the cached ones.
	for k, v := range loaded {
		result[k] = v
	}

	// Only write back when there is something to write. Redis rejects
	// HSET/HMSET with no field/value pairs; whether this store forwards an
	// empty map verbatim is not visible here, so avoid the call entirely.
	if len(loaded) > 0 {
		if err = c.HMSet(ctx, set, loaded, ttl); err != nil {
			return nil, err
		}
	}
	return result, nil
}

// get reads keys from the hash named setKey and reports which of them the
// lookup did not return.
//
// On success the returned map is never nil, so callers may add entries to it.
// missKeys is nil when the lookup returned as many entries as keys were
// requested. Otherwise it holds the requested keys absent from the map, in
// request order, with duplicates preserved.
func (c cache[V]) get(ctx context.Context, setKey string, keys []string) (map[string]V, []string, error) {
	fetchResult, err := c.HMGet(ctx, setKey, keys...)
	if err != nil {
		return nil, nil, err
	}

	// Guarantee a writable map before any early return: WithHashCache merges
	// loaded values into it.
	if fetchResult == nil {
		fetchResult = make(map[string]V, len(keys))
	}

	if len(fetchResult) == len(keys) {
		return fetchResult, nil, nil
	}

	// Collect the keys the caller must load from the source.
	missKeys := lo.Filter(keys, func(item string, _ int) bool {
		return !lo.HasKey(fetchResult, item)
	})
	return fetchResult, missKeys, nil
}

// InvalidateKeys implements Cache.InvalidateKeys by calling the store's
// DeleteMultiple. Store errors are wrapped with context.
func (c cache[V]) InvalidateKeys(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}
	if err := c.Store.DeleteMultiple(ctx, keys...); err != nil {
		return fmt.Errorf("failed to invalidate keys: %w", err)
	}
	return nil
}

// InvalidatePattern implements Cache.InvalidatePattern.
//
// NOTE(review): pattern is passed verbatim to Store.Delete; whether the store
// expands it as a glob/pattern or treats it as a literal key is not visible
// here — confirm against the store.Store implementation.
func (c cache[V]) InvalidatePattern(ctx context.Context, pattern string) error {
	if err := c.Store.Delete(ctx, pattern); err != nil {
		return fmt.Errorf("failed to invalidate pattern: %w", err)
	}
	return nil
}