...
1 package redis_test
2
3 import (
4 "context"
5 "strings"
6 "sync"
7
8 "github.com/redis/go-redis/v9"
9 )
10
11
// commandRecorder keeps a bounded, in-order log of the most recent command
// strings handed to Record. Every method that touches the log takes mu first.
type commandRecorder struct {
	mu       sync.Mutex // guards commands
	commands []string   // recorded commands, oldest first, stored lower-cased
	maxSize  int        // maximum number of entries kept in commands
}

// newCommandRecorder returns an empty recorder that retains at most maxSize
// commands. A negative maxSize is treated as zero (nothing is retained)
// rather than panicking while the backing slice is allocated.
func newCommandRecorder(maxSize int) *commandRecorder {
	if maxSize < 0 {
		// make panics on a negative capacity; clamp so a bad size simply
		// yields a recorder that keeps nothing.
		maxSize = 0
	}
	return &commandRecorder{
		commands: make([]string, 0, maxSize),
		maxSize:  maxSize,
	}
}

// Record appends the lower-cased form of cmd to the log. If that pushes the
// log past maxSize, the oldest entry is dropped.
func (r *commandRecorder) Record(cmd string) {
	// Lower-casing happens outside the lock; it does not touch shared state.
	cmd = strings.ToLower(cmd)

	r.mu.Lock()
	defer r.mu.Unlock()

	r.commands = append(r.commands, cmd)
	if len(r.commands) > r.maxSize {
		// Exactly one entry was added, so trimming one restores the bound.
		r.commands = r.commands[1:]
	}
}

// LastCommands returns a copy of the recorded commands, oldest first. The
// result does not alias the recorder's internal slice, and is nil when
// nothing has been recorded.
func (r *commandRecorder) LastCommands() []string {
	r.mu.Lock()
	defer r.mu.Unlock()
	return append([]string(nil), r.commands...)
}

// Contains reports whether any recorded command contains cmd as a substring.
// cmd is lower-cased first, which makes the match case-insensitive because
// Record stores its entries lower-cased as well.
func (r *commandRecorder) Contains(cmd string) bool {
	cmd = strings.ToLower(cmd)

	r.mu.Lock()
	defer r.mu.Unlock()

	for _, c := range r.commands {
		if strings.Contains(c, cmd) {
			return true
		}
	}
	return false
}
57
58
59 func (r *commandRecorder) Hook() redis.Hook {
60 return &commandHook{recorder: r}
61 }
62
63
64 type commandHook struct {
65 recorder *commandRecorder
66 }
67
68 func (h *commandHook) DialHook(next redis.DialHook) redis.DialHook {
69 return next
70 }
71
72 func (h *commandHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
73 return func(ctx context.Context, cmd redis.Cmder) error {
74 h.recorder.Record(cmd.String())
75 return next(ctx, cmd)
76 }
77 }
78
79 func (h *commandHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
80 return func(ctx context.Context, cmds []redis.Cmder) error {
81 for _, cmd := range cmds {
82 h.recorder.Record(cmd.String())
83 }
84 return next(ctx, cmds)
85 }
86 }
87
View as plain text