Files
base/pkg/store/redis.go
2026-04-10 18:25:21 +03:30

373 lines
9.1 KiB
Go

package store
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
"base/pkg/metrics"
)
// RedisStore is the Redis-backed implementation of Store for values of type V.
// Values are JSON-encoded on write and JSON-decoded on read.
type RedisStore[V any] struct {
	// client is the go-redis client every command is issued through.
	client *redis.Client
	// logger is the injected zerolog logger.
	logger *zerolog.Logger
	// metrics receives per-operation timing and hit/miss observations.
	metrics *metrics.Metrics
}
// NewRedisStore wraps an existing go-redis client, logger and metrics sink in
// a RedisStore and hands it back through the Store interface.
func NewRedisStore[V any](client *redis.Client, logger *zerolog.Logger, metrics *metrics.Metrics) Store[V] {
	store := &RedisStore[V]{}
	store.client = client
	store.logger = logger
	store.metrics = metrics
	return store
}
// Get looks up key and JSON-decodes the stored payload into V.
//
// Results:
//   - (value, true, nil) when the key exists and decodes cleanly;
//   - (zero, false, nil) when the key is absent;
//   - (zero, false, err) when the key pattern cannot be derived, the Redis
//     call fails, or the payload cannot be decoded.
//
// Every lookup that gets past key-pattern extraction is recorded in metrics
// under the "get" operation.
func (c *RedisStore[V]) Get(ctx context.Context, key string) (V, bool, error) {
	var zero V
	keyPattern, err := extractKeyPattern(key)
	if err != nil {
		return zero, false, err
	}
	start := time.Now()
	value, found, getErr := c.get(ctx, key)
	c.metrics.RecordCacheHit("redis", keyPattern, "get", found, getErr, time.Since(start))
	// Return getErr, not err: err is always nil here, and returning it would
	// turn Redis/decode failures into silent cache misses.
	return value, found, getErr
}
// get issues a raw GET for key and JSON-decodes the payload into V.
// A missing key (redis.Nil) is reported as (zero, false, nil); transport and
// decode failures are wrapped with the key.
func (c *RedisStore[V]) get(ctx context.Context, key string) (V, bool, error) {
	var decoded V
	raw, redisErr := c.client.Get(ctx, key).Result()
	switch {
	case errors.Is(redisErr, redis.Nil):
		return decoded, false, nil
	case redisErr != nil:
		return decoded, false, fmt.Errorf("failed to get key %s: %w", key, redisErr)
	}
	if decodeErr := json.Unmarshal([]byte(raw), &decoded); decodeErr != nil {
		// Unmarshal may have partially filled decoded; hand back a clean zero.
		var zero V
		return zero, false, fmt.Errorf("failed to unmarshal cached value for key %s: %w", key, decodeErr)
	}
	return decoded, true, nil
}
// Set encodes value and writes it under key with the given expiration,
// delegating the work to set.
func (c *RedisStore[V]) Set(ctx context.Context, key string, value V, expiration time.Duration) error {
	if err := c.set(ctx, key, value, expiration); err != nil {
		return err
	}
	return nil
}
// set encodes value with marshalValue and issues a Redis SET with the given
// expiration. SET failures are wrapped with the key.
func (c *RedisStore[V]) set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
	payload, encodeErr := marshalValue(key, value)
	if encodeErr != nil {
		return encodeErr
	}
	if setErr := c.client.Set(ctx, key, payload, expiration).Err(); setErr != nil {
		return fmt.Errorf("failed to set key %s: %w", key, setErr)
	}
	return nil
}
// Delete removes key from Redis.
//
// It routes through delete so that failures are wrapped with the key (via
// %w, so errors.Is/As on the underlying error still work), matching how the
// other public methods report errors.
func (c *RedisStore[V]) Delete(ctx context.Context, key string) error {
	return c.delete(ctx, key)
}
// delete issues DEL for a single key and annotates any failure with that key.
func (c *RedisStore[V]) delete(ctx context.Context, key string) error {
	if delErr := c.client.Del(ctx, key).Err(); delErr != nil {
		return fmt.Errorf("failed to delete key %s: %w", key, delErr)
	}
	return nil
}
// Exists reports whether key is present in Redis. The check is timed and
// recorded in metrics under the "exists" operation once the key pattern has
// been extracted.
func (c *RedisStore[V]) Exists(ctx context.Context, key string) (bool, error) {
	pattern, patternErr := extractKeyPattern(key)
	if patternErr != nil {
		return false, patternErr
	}
	began := time.Now()
	found, existsErr := c.exists(ctx, key)
	elapsed := time.Since(began)
	c.metrics.RecordCacheHit("redis", pattern, "exists", found, existsErr, elapsed)
	return found, existsErr
}
// exists runs EXISTS for key and converts the returned count into a bool.
func (c *RedisStore[V]) exists(ctx context.Context, key string) (bool, error) {
	count, redisErr := c.client.Exists(ctx, key).Result()
	if redisErr != nil {
		return false, fmt.Errorf("failed to check existence of key %s: %w", key, redisErr)
	}
	return count > 0, nil
}
// SetNX writes value under key only when key does not already exist, and
// reports whether the write took place. The attempt is timed and recorded in
// metrics under the "setNx" operation.
func (c *RedisStore[V]) SetNX(ctx context.Context, key string, value V, expiration time.Duration) (bool, error) {
	pattern, patternErr := extractKeyPattern(key)
	if patternErr != nil {
		return false, patternErr
	}
	began := time.Now()
	stored, setErr := c.setNX(ctx, key, value, expiration)
	elapsed := time.Since(began)
	c.metrics.RecordCacheHit("redis", pattern, "setNx", stored, setErr, elapsed)
	return stored, setErr
}
// setNX encodes value with marshalValue (the same encoder used by set and
// setMultiple, so every write path produces identical payloads) and issues a
// Redis SetNX. It reports whether the key was written; Redis failures are
// wrapped with the key.
func (c *RedisStore[V]) setNX(ctx context.Context, key string, value V, expiration time.Duration) (bool, error) {
	data, err := marshalValue(key, value)
	if err != nil {
		return false, err
	}
	written, err := c.client.SetNX(ctx, key, data, expiration).Result()
	if err != nil {
		return false, fmt.Errorf("failed to set key %s with NX: %w", key, err)
	}
	return written, nil
}
// HMGet fetches the named fields from the hash at key and JSON-decodes each
// present field into V. Absent fields are omitted from the returned map.
//
// The call is timed and recorded in metrics under the "hmget" operation. It
// counts as a hit when at least one field came back without error. Redis and
// decode failures are returned to the caller.
func (c *RedisStore[V]) HMGet(ctx context.Context, key string, keys ...string) (map[string]V, error) {
	keyPattern, err := extractKeyPattern(key)
	if err != nil {
		return nil, err
	}
	start := time.Now()
	result, getErr := c.hmGet(ctx, key, keys...)
	c.metrics.RecordCacheHit("redis", keyPattern, "hmget", len(result) > 0 && getErr == nil, getErr, time.Since(start))
	// Return getErr, not err: err is always nil here.
	return result, getErr
}
// hmGet issues HMGET for fields on the hash at key and decodes each non-nil
// entry into V via serializeValue. Fields that come back nil are skipped.
//
// An empty field list is answered locally with an empty map: Redis rejects
// HMGET without at least one field. This mirrors the empty-input guards in
// hmSet and deleteMultiple.
func (c *RedisStore[V]) hmGet(ctx context.Context, key string, fields ...string) (map[string]V, error) {
	if len(fields) == 0 {
		return map[string]V{}, nil
	}
	vals, err := c.client.HMGet(ctx, key, fields...).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to hmget key %s: %w", key, err)
	}
	result := make(map[string]V, len(fields))
	for i, field := range fields {
		// A nil entry means the field is not present in the hash.
		if vals[i] == nil {
			continue
		}
		decoded, decodeErr := serializeValue[V](vals[i])
		if decodeErr != nil {
			return nil, fmt.Errorf("failed to decode field %s of key %s: %w", field, key, decodeErr)
		}
		result[field] = decoded
	}
	return result, nil
}
// HGetAll fetches every field of the hash at key and JSON-decodes each value
// into V, keyed by field name.
//
// The call is timed and recorded in metrics under the "hgetall" operation (it
// was previously mislabelled "hmget"). It counts as a hit when at least one
// field came back without error. Redis and decode failures are returned to
// the caller.
func (c *RedisStore[V]) HGetAll(ctx context.Context, key string) (map[string]V, error) {
	keyPattern, err := extractKeyPattern(key)
	if err != nil {
		return nil, err
	}
	start := time.Now()
	result, getErr := c.hGetAll(ctx, key)
	c.metrics.RecordCacheHit("redis", keyPattern, "hgetall", len(result) > 0 && getErr == nil, getErr, time.Since(start))
	// Return getErr, not err: err is always nil here.
	return result, getErr
}
// hGetAll issues HGETALL on key and decodes every field's value into V via
// serializeValue. The returned map is keyed by the hash's field names.
func (c *RedisStore[V]) hGetAll(ctx context.Context, key string) (map[string]V, error) {
	fields, err := c.client.HGetAll(ctx, key).Result()
	if err != nil {
		return nil, fmt.Errorf("failed to hgetall key %s: %w", key, err)
	}
	result := make(map[string]V, len(fields))
	// Range over (field, value) pairs. Ranging over values only would key the
	// result by the payload instead of the field name.
	for field, raw := range fields {
		decoded, decodeErr := serializeValue[V](raw)
		if decodeErr != nil {
			return nil, fmt.Errorf("failed to decode field %s of key %s: %w", field, key, decodeErr)
		}
		result[field] = decoded
	}
	return result, nil
}
// HMSet stores every entry of values as a field of the hash at key. When
// expiration is positive, it also applies a TTL to the hash. The work is done
// by hmSet.
func (c *RedisStore[V]) HMSet(ctx context.Context, key string, values map[string]V, expiration time.Duration) error {
	err := c.hmSet(ctx, key, values, expiration)
	return err
}
// hmSet JSON-encodes every value and writes the fields into the hash at key.
// An empty values map is a no-op.
//
// When expiration > 0, EXPIRE is queued in the same MULTI/EXEC transaction as
// the field write. A failure or dropped connection between two separate round
// trips therefore cannot leave the hash written but without its TTL. HSET is
// used because HMSET is deprecated in Redis >= 4.0 and go-redis.
func (c *RedisStore[V]) hmSet(ctx context.Context, key string, values map[string]V, expiration time.Duration) error {
	if len(values) == 0 {
		return nil
	}
	// Encode everything up front so nothing is sent if any value fails.
	hashValues := make(map[string]interface{}, len(values))
	for field, value := range values {
		encoded, err := json.Marshal(value)
		if err != nil {
			return fmt.Errorf("failed to marshal field %s for key %s: %w", field, key, err)
		}
		hashValues[field] = encoded
	}
	_, err := c.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(ctx, key, hashValues)
		if expiration > 0 {
			pipe.Expire(ctx, key, expiration)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("failed to hmset key %s: %w", key, err)
	}
	return nil
}
// SetMultiple writes every key/value pair in items with the same expiration.
// An empty items map does nothing.
func (c *RedisStore[V]) SetMultiple(ctx context.Context, items map[string]V, expiration time.Duration) error {
	if len(items) > 0 {
		return c.setMultiple(ctx, items, expiration)
	}
	return nil
}
// setMultiple encodes every item with marshalValue and sends all SETs in a
// single pipeline. If any value fails to encode, nothing is sent.
func (c *RedisStore[V]) setMultiple(ctx context.Context, items map[string]V, expiration time.Duration) error {
	encoded := make(map[string][]byte, len(items))
	for key, value := range items {
		payload, encodeErr := marshalValue(key, value)
		if encodeErr != nil {
			return encodeErr
		}
		encoded[key] = payload
	}

	pipe := c.client.Pipeline()
	for key, payload := range encoded {
		pipe.Set(ctx, key, payload, expiration)
	}
	if _, execErr := pipe.Exec(ctx); execErr != nil {
		return fmt.Errorf("failed to set multiple keys: %w", execErr)
	}
	return nil
}
// marshalValue JSON-encodes value for storage under key.
//
// If encoding fails and value is a plain string, the raw string bytes are
// returned instead. Any other encoding failure is returned wrapped with key.
// NOTE(review): encoding/json does not report errors for string values, so
// the string fallback appears unreachable in practice; confirm before relying
// on it.
func marshalValue(key string, value interface{}) ([]byte, error) {
	data, err := json.Marshal(value)
	if err == nil {
		return data, nil
	}
	if s, isString := value.(string); isString {
		return []byte(s), nil
	}
	return nil, fmt.Errorf("failed to marshal value for key %s: %w", key, err)
}
// DeleteMultiple removes all the given keys in one DEL. Calling it with no
// keys does nothing.
func (c *RedisStore[V]) DeleteMultiple(ctx context.Context, keys ...string) error {
	if len(keys) > 0 {
		return c.deleteMultiple(ctx, keys...)
	}
	return nil
}
// deleteMultiple issues a single DEL covering every key. An empty key list is
// a no-op, and a DEL failure is wrapped with context.
func (c *RedisStore[V]) deleteMultiple(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}
	if delErr := c.client.Del(ctx, keys...).Err(); delErr != nil {
		return fmt.Errorf("failed to delete multiple keys: %w", delErr)
	}
	return nil
}
// DeletePattern removes every key whose name matches the given SCAN MATCH
// pattern. It walks the keyspace incrementally via deletePattern.
func (c *RedisStore[V]) DeletePattern(ctx context.Context, pattern string) error {
	if err := c.deletePattern(ctx, pattern); err != nil {
		return err
	}
	return nil
}
// deletePattern iterates the keyspace with SCAN (COUNT hint 100), deleting
// each non-empty batch of matching keys with DEL. It stops when SCAN hands
// back cursor 0. The first SCAN or DEL failure aborts the walk.
func (c *RedisStore[V]) deletePattern(ctx context.Context, pattern string) error {
	const scanCount = 100
	cursor := uint64(0)
	for {
		batch, next, scanErr := c.client.Scan(ctx, cursor, pattern, scanCount).Result()
		if scanErr != nil {
			return fmt.Errorf("failed to scan keys with pattern %s: %w", pattern, scanErr)
		}
		if len(batch) != 0 {
			if delErr := c.client.Del(ctx, batch...).Err(); delErr != nil {
				return fmt.Errorf("failed to delete keys matching pattern %s: %w", pattern, delErr)
			}
		}
		// A zero cursor signals the full iteration is complete.
		if next == 0 {
			return nil
		}
		cursor = next
	}
}