package redis

import (
	"context"
	"fmt"
	"reflect"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/internal/proto"

	. "github.com/bsm/ginkgo/v2"
	. "github.com/bsm/gomega"
)

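// newClusterState builds a clusterState from CLUSTER SLOTS data. These specs
// check that slot ranges are sorted and that loopback node addresses are
// rewritten to the host the state was loaded from.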
var _ = Describe("newClusterState", func() {
	var state *clusterState

	createClusterState := func(slots []ClusterSlot) *clusterState {
		opt := &ClusterOptions{}
		opt.init()
		nodes := newClusterNodes(opt)
		state, err := newClusterState(nodes, slots, "10.10.10.10:1234")
		Expect(err).NotTo(HaveOccurred())
		return state
	}

	Describe("sorting", func() {
		BeforeEach(func() {
			state = createClusterState([]ClusterSlot{{
				Start: 1000,
				End:   1999,
			}, {
				Start: 0,
				End:   999,
			}, {
				Start: 2000,
				End:   2999,
			}})
		})

		It("sorts slots", func() {
			Expect(state.slots).To(Equal([]*clusterSlot{
				{start: 0, end: 999, nodes: nil},
				{start: 1000, end: 1999, nodes: nil},
				{start: 2000, end: 2999, nodes: nil},
			}))
		})
	})

	Describe("loopback", func() {
		BeforeEach(func() {
			state = createClusterState([]ClusterSlot{{
				Nodes: []ClusterNode{{Addr: "127.0.0.1:7001"}},
			}, {
				Nodes: []ClusterNode{{Addr: "127.0.0.1:7002"}},
			}, {
				Nodes: []ClusterNode{{Addr: "1.2.3.4:1234"}},
			}, {
				Nodes: []ClusterNode{{Addr: ":1234"}},
			}})
		})

		It("replaces loopback hosts in addresses", func() {
			slotAddr := func(slot *clusterSlot) string {
				return slot.nodes[0].Client.Options().Addr
			}

			Expect(slotAddr(state.slots[0])).To(Equal("10.10.10.10:7001"))
			Expect(slotAddr(state.slots[1])).To(Equal("10.10.10.10:7002"))
			Expect(slotAddr(state.slots[2])).To(Equal("1.2.3.4:1234"))
			Expect(slotAddr(state.slots[3])).To(Equal(":1234"))
		})
	})
})

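// fixedHash is a ConsistentHash stub that maps every key to the same shard
// name, making shard selection deterministic in tests.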
type fixedHash string

func (h fixedHash) Get(string) string {
	return string(h)
}

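// TestRingSetAddrsAndRebalanceRace calls SetAddrs in a tight loop from one
// goroutine while another looks up shards by key, to surface races between
// rebalancing and reads: GetByKey must never return a nil shard without an
// error.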
func TestRingSetAddrsAndRebalanceRace(t *testing.T) {
	const (
		ringShard1Name = "ringShardOne"
		ringShard2Name = "ringShardTwo"

		ringShard1Port = "6390"
		ringShard2Port = "6391"
	)

	ring := NewRing(&RingOptions{
		Addrs: map[string]string{
			ringShard1Name: ":" + ringShard1Port,
		},
		// A one-hour heartbeat effectively disables background rebalancing,
		// so rebalances happen only through SetAddrs.
		HeartbeatFrequency: 1 * time.Hour,
		NewConsistentHash: func(shards []string) ConsistentHash {
			switch len(shards) {
			case 1:
				return fixedHash(ringShard1Name)
			case 2:
				return fixedHash(ringShard2Name)
			default:
				t.Fatalf("Unexpected number of shards: %v", shards)
				return nil
			}
		},
	})
	defer ring.Close()

	// Flip between one and two shards as fast as possible.
	updatesDone := make(chan struct{})
	defer func() { close(updatesDone) }()
	go func() {
		for i := 0; ; i++ {
			select {
			case <-updatesDone:
				return
			default:
				if i%2 == 0 {
					ring.SetAddrs(map[string]string{
						ringShard1Name: ":" + ringShard1Port,
					})
				} else {
					ring.SetAddrs(map[string]string{
						ringShard1Name: ":" + ringShard1Port,
						ringShard2Name: ":" + ringShard2Port,
					})
				}
			}
		}
	}()

	// Concurrently read shards by key for one second.
	timer := time.NewTimer(1 * time.Second)
	for running := true; running; {
		select {
		case <-timer.C:
			running = false
		default:
			shard, err := ring.sharding.GetByKey("whatever")
			if err == nil && shard == nil {
				t.Fatal("shard is nil")
			}
		}
	}
}

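// BenchmarkRingShardingRebalanceLocked measures the cost of a single
// rebalance of a 100-shard ring.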
func BenchmarkRingShardingRebalanceLocked(b *testing.B) {
	opts := &RingOptions{
		Addrs: make(map[string]string),
		// A one-hour heartbeat keeps background rebalancing out of the
		// benchmark loop.
		HeartbeatFrequency: 1 * time.Hour,
	}
	for i := 0; i < 100; i++ {
		opts.Addrs[fmt.Sprintf("shard%d", i)] = fmt.Sprintf(":63%02d", i)
	}

	ring := NewRing(opts)
	defer ring.Close()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		ring.sharding.rebalanceLocked()
	}
}

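// testCounter is a mutex-guarded map of counters, used to record per-address
// create/close events from concurrent callbacks.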
type testCounter struct {
	mu sync.Mutex
	t  *testing.T
	m  map[string]int
}

func newTestCounter(t *testing.T) *testCounter {
	return &testCounter{t: t, m: make(map[string]int)}
}

func (ct *testCounter) increment(key string) {
	ct.mu.Lock()
	defer ct.mu.Unlock()
	ct.m[key]++
}

func (ct *testCounter) expect(values map[string]int) {
	ct.mu.Lock()
	defer ct.mu.Unlock()
	ct.t.Helper()
	if !reflect.DeepEqual(values, ct.m) {
		ct.t.Errorf("expected %v != actual %v", values, ct.m)
	}
}

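// TestRingShardsCleanup verifies that SetAddrs closes shard clients that are
// dropped from the ring, and that shards created while the ring is being
// closed are closed as well.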
func TestRingShardsCleanup(t *testing.T) {
	const (
		ringShard1Name = "ringShardOne"
		ringShard2Name = "ringShardTwo"

		ringShard1Addr = "shard1.test"
		ringShard2Addr = "shard2.test"
	)

	t.Run("closes unused shards", func(t *testing.T) {
		closeCounter := newTestCounter(t)

		ring := NewRing(&RingOptions{
			Addrs: map[string]string{
				ringShard1Name: ringShard1Addr,
				ringShard2Name: ringShard2Addr,
			},
			NewClient: func(opt *Options) *Client {
				c := NewClient(opt)
				c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
					closeCounter.increment(opt.Addr)
					return nil
				})
				return c
			},
		})
		closeCounter.expect(map[string]int{})

		// Set the same addresses; nothing should be closed.
		ring.SetAddrs(map[string]string{
			ringShard1Name: ringShard1Addr,
			ringShard2Name: ringShard2Addr,
		})
		closeCounter.expect(map[string]int{})

		// Drop shard 2; its client should be closed.
		ring.SetAddrs(map[string]string{
			ringShard1Name: ringShard1Addr,
		})
		closeCounter.expect(map[string]int{ringShard2Addr: 1})

		// Swap shard 1 for shard 2; shard 1's client should be closed and a
		// fresh shard 2 client created.
		ring.SetAddrs(map[string]string{
			ringShard2Name: ringShard2Addr,
		})
		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})

		// Closing the ring closes the remaining shard 2 client.
		ring.Close()
		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2})
	})

	t.Run("closes created shards if ring was closed", func(t *testing.T) {
		createCounter := newTestCounter(t)
		closeCounter := newTestCounter(t)

		var (
			ring        *Ring
			shouldClose int32
		)

		ring = NewRing(&RingOptions{
			Addrs: map[string]string{
				ringShard1Name: ringShard1Addr,
			},
			NewClient: func(opt *Options) *Client {
				if atomic.LoadInt32(&shouldClose) != 0 {
					ring.Close()
				}
				createCounter.increment(opt.Addr)
				c := NewClient(opt)
				c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
					closeCounter.increment(opt.Addr)
					return nil
				})
				return c
			},
		})
		createCounter.expect(map[string]int{ringShard1Addr: 1})
		closeCounter.expect(map[string]int{})

		// From now on, creating a shard client closes the ring mid-SetAddrs;
		// the freshly created shard must still end up closed.
		atomic.StoreInt32(&shouldClose, 1)

		ring.SetAddrs(map[string]string{
			ringShard2Name: ringShard2Addr,
		})
		createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
	})
}

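// timeoutErr provides the Timeout/Temporary/Error methods of net.Error so
// tests can simulate an i/o timeout, which the client treats as a bad
// connection.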
type timeoutErr struct {
	error
}

func (e timeoutErr) Timeout() bool {
	return true
}

func (e timeoutErr) Temporary() bool {
	return true
}

func (e timeoutErr) Error() string {
	return "i/o timeout"
}

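// withConn specs: a healthy connection, or one that failed with a benign
// Redis error such as LOADING, goes back into the pool; a timeout causes the
// connection to be removed and replaced.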
var _ = Describe("withConn", func() {
	var client *Client

	BeforeEach(func() {
		client = NewClient(&Options{
			PoolSize: 1,
		})
	})

	AfterEach(func() {
		client.Close()
	})

	It("should replace the connection in the pool when there is no error", func() {
		var conn *pool.Conn

		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
			conn = c
			return nil
		})

		newConn, err := client.connPool.Get(ctx)
		Expect(err).To(BeNil())
		Expect(newConn).To(Equal(conn))
	})

	It("should replace the connection in the pool when there is an error not related to a bad connection", func() {
		var conn *pool.Conn

		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
			conn = c
			return proto.RedisError("LOADING")
		})

		newConn, err := client.connPool.Get(ctx)
		Expect(err).To(BeNil())
		Expect(newConn).To(Equal(conn))
	})

	It("should remove the connection from the pool when it times out", func() {
		var conn *pool.Conn

		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
			conn = c
			return timeoutErr{}
		})

		newConn, err := client.connPool.Get(ctx)
		Expect(err).To(BeNil())
		Expect(newConn).NotTo(Equal(conn))
		Expect(client.connPool.Len()).To(Equal(1))
	})
})

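// cmdSlot specs: CLUSTER GETKEYSINSLOT and CLUSTER COUNTKEYSINSLOT carry
// their slot number as an argument rather than deriving it from a key, and
// keyless commands fall back to the preferred (random) slot.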
var _ = Describe("ClusterClient", func() {
	var client *ClusterClient

	BeforeEach(func() {
		client = &ClusterClient{}
	})

	Describe("cmdSlot", func() {
		It("selects the slot from the args of the GETKEYSINSLOT command", func() {
			cmd := NewStringSliceCmd(ctx, "cluster", "getkeysinslot", 100, 200)

			slot := client.cmdSlot(cmd, -1)
			Expect(slot).To(Equal(100))
		})

		It("selects the slot from the args of the COUNTKEYSINSLOT command", func() {
			cmd := NewStringSliceCmd(ctx, "cluster", "countkeysinslot", 100)

			slot := client.cmdSlot(cmd, -1)
			Expect(slot).To(Equal(100))
		})

		It("follows the preferred random slot", func() {
			cmd := NewStatusCmd(ctx, "ping")

			slot := client.cmdSlot(cmd, 101)
			Expect(slot).To(Equal(101))
		})
	})
})

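// isLoopback reports whether a host refers to the local machine: loopback
// IPs, localhost in any case, and hostnames ending in .docker.internal.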
var _ = Describe("isLoopback", func() {
	DescribeTable("should correctly identify loopback addresses",
		func(host string, expected bool) {
			result := isLoopback(host)
			Expect(result).To(Equal(expected))
		},

		// IP addresses
		Entry("IPv4 loopback", "127.0.0.1", true),
		Entry("IPv6 loopback", "::1", true),
		Entry("IPv4 non-loopback", "192.168.1.1", false),
		Entry("IPv6 non-loopback", "2001:db8::1", false),

		// localhost is matched case-insensitively
		Entry("localhost lowercase", "localhost", true),
		Entry("localhost uppercase", "LOCALHOST", true),
		Entry("localhost mixed case", "LocalHost", true),

		// hostnames ending in .docker.internal are treated as loopback
		Entry("host.docker.internal", "host.docker.internal", true),
		Entry("HOST.DOCKER.INTERNAL", "HOST.DOCKER.INTERNAL", true),
		Entry("custom.docker.internal", "custom.docker.internal", true),
		Entry("app.docker.internal", "app.docker.internal", true),

		// regular hostnames are not loopback
		Entry("redis hostname", "redis-cluster", false),
		Entry("FQDN", "redis.example.com", false),
		Entry("docker but not internal", "redis.docker.com", false),

		// edge cases
		Entry("empty string", "", false),
		Entry("invalid IP", "256.256.256.256", false),
		Entry("partial docker internal", "docker.internal", false),
	)
})