130 lines
3.0 KiB
Go
130 lines
3.0 KiB
Go
package cache
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/samber/lo"
|
|
|
|
"base/pkg/store"
|
|
)
|
|
|
|
// Cache provides read-through helpers on top of a backing store plus explicit
// invalidation. V is the type of the cached values.
type Cache[V any] interface {
	// WithCache returns the value cached under key. On a cache miss it calls
	// fn, stores the result under key for ttl, and returns it.
	WithCache(ctx context.Context, key string, fn func(context.Context) (V, error), ttl time.Duration) (V, error)

	// WithHashCache returns the values for keys grouped under set. Keys that
	// are not cached are loaded by calling fn with just those keys; the loaded
	// values are written back under set for ttl and merged into the result.
	// When keys is empty the cache is not read and fn is called with a nil
	// slice, which callers use to load everything from the source.
	WithHashCache(ctx context.Context, set string, keys []string, fn func(context.Context, []string) (map[string]V, error), ttl time.Duration) (map[string]V, error)

	// InvalidateKeys deletes the given keys. Calling it with no keys is a
	// no-op.
	InvalidateKeys(ctx context.Context, keys ...string) error

	// InvalidatePattern invalidates the entries identified by pattern. How
	// pattern is interpreted is up to the backing store.
	InvalidatePattern(ctx context.Context, pattern string) error
}
|
|
|
|
type cache[V any] struct {
|
|
store.Store[V]
|
|
}
|
|
|
|
func New[V any](store store.Store[V]) Cache[V] {
|
|
return cache[V]{store}
|
|
}
|
|
|
|
func (c cache[V]) WithCache(ctx context.Context, key string, fn func(context.Context) (V, error), ttl time.Duration) (V, error) {
|
|
result, found, err := c.Get(ctx, key)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
if found {
|
|
return result, nil
|
|
}
|
|
|
|
result, err = fn(ctx)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
err = c.Set(ctx, key, result, ttl)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (c cache[V]) WithHashCache(ctx context.Context, set string, keys []string, fn func(context.Context, []string) (map[string]V, error), ttl time.Duration) (map[string]V, error) {
|
|
fetchResult := make(map[string]V, len(keys))
|
|
var missKeys []string
|
|
var getError error
|
|
|
|
// step 1 try to get from redis and figure out missedKeys
|
|
// for when there are no keys ignore cache retrieve all from source
|
|
if len(keys) > 0 {
|
|
fetchResult, missKeys, getError = c.get(ctx, set, keys)
|
|
if getError != nil {
|
|
return nil, getError
|
|
}
|
|
|
|
// all target key founded
|
|
if len(missKeys) == 0 {
|
|
return fetchResult, nil
|
|
}
|
|
}
|
|
|
|
//fetch missedKeys from source
|
|
newResult, fnErr := fn(ctx, missKeys)
|
|
if fnErr != nil {
|
|
return nil, fnErr
|
|
}
|
|
|
|
// append new result to fetchResult
|
|
for key, val := range newResult {
|
|
fetchResult[key] = val
|
|
}
|
|
|
|
// set new founded keys
|
|
setErr := c.HMSet(ctx, set, newResult, ttl)
|
|
if setErr != nil {
|
|
return nil, setErr
|
|
}
|
|
|
|
return fetchResult, nil
|
|
}
|
|
|
|
func (c cache[V]) get(ctx context.Context, setKey string, keys []string) (map[string]V, []string, error) {
|
|
fetchResult, fetchErr := c.HMGet(ctx, setKey, keys...)
|
|
if fetchErr != nil {
|
|
return nil, nil, fetchErr
|
|
}
|
|
|
|
if len(fetchResult) == len(keys) {
|
|
return fetchResult, nil, nil
|
|
}
|
|
|
|
if len(fetchResult) == 0 {
|
|
// just for avoid nil panic in higher layer
|
|
fetchResult = make(map[string]V, len(keys))
|
|
}
|
|
|
|
// found miss key for fetch from source in higher level
|
|
missKeys := lo.Filter(keys, func(item string, index int) bool { return !lo.HasKey(fetchResult, item) })
|
|
|
|
return fetchResult, missKeys, nil
|
|
}
|
|
|
|
func (c cache[V]) InvalidateKeys(ctx context.Context, keys ...string) error {
|
|
if len(keys) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if err := c.Store.DeleteMultiple(ctx, keys...); err != nil {
|
|
return fmt.Errorf("failed to invalidate keys: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c cache[V]) InvalidatePattern(ctx context.Context, pattern string) error {
|
|
if err := c.Store.Delete(ctx, pattern); err != nil {
|
|
return fmt.Errorf("failed to invalidate pattern: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|