1 package redis_test
2
3 import (
4 "context"
5 "strconv"
6 "sync"
7
8 . "github.com/bsm/ginkgo/v2"
9 . "github.com/bsm/gomega"
10
11 "github.com/redis/go-redis/v9"
12 )
13
14 var _ = Describe("Tx", func() {
15 var client *redis.Client
16
17 BeforeEach(func() {
18 client = redis.NewClient(redisOptions())
19 Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
20 })
21
22 AfterEach(func() {
23 Expect(client.Close()).NotTo(HaveOccurred())
24 })
25
26 It("should Watch", func() {
27 var incr func(string) error
28
29
30 incr = func(key string) error {
31 err := client.Watch(ctx, func(tx *redis.Tx) error {
32 n, err := tx.Get(ctx, key).Int64()
33 if err != nil && err != redis.Nil {
34 return err
35 }
36
37 _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
38 pipe.Set(ctx, key, strconv.FormatInt(n+1, 10), 0)
39 return nil
40 })
41 return err
42 }, key)
43 if err == redis.TxFailedErr {
44 return incr(key)
45 }
46 return err
47 }
48
49 var wg sync.WaitGroup
50 for i := 0; i < 100; i++ {
51 wg.Add(1)
52 go func() {
53 defer GinkgoRecover()
54 defer wg.Done()
55
56 err := incr("key")
57 Expect(err).NotTo(HaveOccurred())
58 }()
59 }
60 wg.Wait()
61
62 n, err := client.Get(ctx, "key").Int64()
63 Expect(err).NotTo(HaveOccurred())
64 Expect(n).To(Equal(int64(100)))
65 })
66
67 It("should discard", Label("NonRedisEnterprise"), func() {
68 err := client.Watch(ctx, func(tx *redis.Tx) error {
69 cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
70 pipe.Set(ctx, "key1", "hello1", 0)
71 pipe.Discard()
72 pipe.Set(ctx, "key2", "hello2", 0)
73 return nil
74 })
75 Expect(err).NotTo(HaveOccurred())
76 Expect(cmds).To(HaveLen(1))
77 return err
78 }, "key1", "key2")
79 Expect(err).NotTo(HaveOccurred())
80
81 get := client.Get(ctx, "key1")
82 Expect(get.Err()).To(Equal(redis.Nil))
83 Expect(get.Val()).To(Equal(""))
84
85 get = client.Get(ctx, "key2")
86 Expect(get.Err()).NotTo(HaveOccurred())
87 Expect(get.Val()).To(Equal("hello2"))
88 })
89
90 It("returns no error when there are no commands", func() {
91 err := client.Watch(ctx, func(tx *redis.Tx) error {
92 _, err := tx.TxPipelined(ctx, func(redis.Pipeliner) error { return nil })
93 return err
94 })
95 Expect(err).NotTo(HaveOccurred())
96
97 v, err := client.Ping(ctx).Result()
98 Expect(err).NotTo(HaveOccurred())
99 Expect(v).To(Equal("PONG"))
100 })
101
102 It("should exec bulks", func() {
103 const N = 20000
104
105 err := client.Watch(ctx, func(tx *redis.Tx) error {
106 cmds, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
107 for i := 0; i < N; i++ {
108 pipe.Incr(ctx, "key")
109 }
110 return nil
111 })
112 Expect(err).NotTo(HaveOccurred())
113 Expect(len(cmds)).To(Equal(N))
114 for _, cmd := range cmds {
115 Expect(cmd.Err()).NotTo(HaveOccurred())
116 }
117 return err
118 })
119 Expect(err).NotTo(HaveOccurred())
120
121 num, err := client.Get(ctx, "key").Int64()
122 Expect(err).NotTo(HaveOccurred())
123 Expect(num).To(Equal(int64(N)))
124 })
125
126 It("should recover from bad connection", func() {
127
128 cn, err := client.Pool().Get(context.Background())
129 Expect(err).NotTo(HaveOccurred())
130
131 cn.SetNetConn(&badConn{})
132 client.Pool().Put(ctx, cn)
133
134 do := func() error {
135 err := client.Watch(ctx, func(tx *redis.Tx) error {
136 _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
137 pipe.Ping(ctx)
138 return nil
139 })
140 return err
141 })
142 return err
143 }
144
145 err = do()
146 Expect(err).NotTo(HaveOccurred())
147 })
148 })
149
View as plain text