1 package redis_test
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "net"
9 "sync"
10 "testing"
11 "time"
12
13 . "github.com/bsm/ginkgo/v2"
14 . "github.com/bsm/gomega"
15
16 "github.com/redis/go-redis/v9"
17 "github.com/redis/go-redis/v9/auth"
18 )
19
20 type redisHookError struct{}
21
22 var _ redis.Hook = redisHookError{}
23
24 func (redisHookError) DialHook(hook redis.DialHook) redis.DialHook {
25 return hook
26 }
27
28 func (redisHookError) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
29 return func(ctx context.Context, cmd redis.Cmder) error {
30 return errors.New("hook error")
31 }
32 }
33
34 func (redisHookError) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
35 return hook
36 }
37
38 func TestHookError(t *testing.T) {
39 rdb := redis.NewClient(&redis.Options{
40 Addr: ":6379",
41 })
42 rdb.AddHook(redisHookError{})
43
44 err := rdb.Ping(ctx).Err()
45 if err == nil {
46 t.Fatalf("got nil, expected an error")
47 }
48
49 wanted := "hook error"
50 if err.Error() != wanted {
51 t.Fatalf(`got %q, wanted %q`, err, wanted)
52 }
53 }
54
55
56
57 var _ = Describe("Client", func() {
58 var client *redis.Client
59
60 BeforeEach(func() {
61 client = redis.NewClient(redisOptions())
62 Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
63 })
64
65 AfterEach(func() {
66 client.Close()
67 })
68
69 It("should Stringer", func() {
70 Expect(client.String()).To(Equal(fmt.Sprintf("Redis<:%s db:0>", redisPort)))
71 })
72
73 It("supports context", func() {
74 ctx, cancel := context.WithCancel(ctx)
75 cancel()
76
77 err := client.Ping(ctx).Err()
78 Expect(err).To(MatchError("context canceled"))
79 })
80
81 It("supports WithTimeout", Label("NonRedisEnterprise"), func() {
82 err := client.ClientPause(ctx, time.Second).Err()
83 Expect(err).NotTo(HaveOccurred())
84
85 err = client.WithTimeout(10 * time.Millisecond).Ping(ctx).Err()
86 Expect(err).To(HaveOccurred())
87
88 err = client.Ping(ctx).Err()
89 Expect(err).NotTo(HaveOccurred())
90 })
91
92 It("do", func() {
93 val, err := client.Do(ctx, "ping").Result()
94 Expect(err).NotTo(HaveOccurred())
95 Expect(val).To(Equal("PONG"))
96 })
97
98 It("should ping", func() {
99 val, err := client.Ping(ctx).Result()
100 Expect(err).NotTo(HaveOccurred())
101 Expect(val).To(Equal("PONG"))
102 })
103
104 It("should return pool stats", func() {
105 Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{}))
106 })
107
108 It("should support custom dialers", func() {
109 custom := redis.NewClient(&redis.Options{
110 Network: "tcp",
111 Addr: redisAddr,
112 Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
113 var d net.Dialer
114 return d.DialContext(ctx, network, addr)
115 },
116 })
117
118 val, err := custom.Ping(ctx).Result()
119 Expect(err).NotTo(HaveOccurred())
120 Expect(val).To(Equal("PONG"))
121 Expect(custom.Close()).NotTo(HaveOccurred())
122 })
123
124 It("should close", func() {
125 Expect(client.Close()).NotTo(HaveOccurred())
126 err := client.Ping(ctx).Err()
127 Expect(err).To(MatchError("redis: client is closed"))
128 })
129
130 It("should close pubsub without closing the client", func() {
131 pubsub := client.Subscribe(ctx)
132 Expect(pubsub.Close()).NotTo(HaveOccurred())
133
134 _, err := pubsub.Receive(ctx)
135 Expect(err).To(MatchError("redis: client is closed"))
136 Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
137 })
138
139 It("should close Tx without closing the client", func() {
140 err := client.Watch(ctx, func(tx *redis.Tx) error {
141 _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
142 pipe.Ping(ctx)
143 return nil
144 })
145 return err
146 })
147 Expect(err).NotTo(HaveOccurred())
148
149 Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
150 })
151
152 It("should close pubsub when client is closed", func() {
153 pubsub := client.Subscribe(ctx)
154 Expect(client.Close()).NotTo(HaveOccurred())
155
156 _, err := pubsub.Receive(ctx)
157 Expect(err).To(MatchError("redis: client is closed"))
158
159 Expect(pubsub.Close()).NotTo(HaveOccurred())
160 })
161
162 It("should select DB", Label("NonRedisEnterprise"), func() {
163 db2 := redis.NewClient(&redis.Options{
164 Addr: redisAddr,
165 DB: 2,
166 })
167 Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
168 Expect(db2.Get(ctx, "db").Err()).To(Equal(redis.Nil))
169 Expect(db2.Set(ctx, "db", 2, 0).Err()).NotTo(HaveOccurred())
170
171 n, err := db2.Get(ctx, "db").Int64()
172 Expect(err).NotTo(HaveOccurred())
173 Expect(n).To(Equal(int64(2)))
174
175 Expect(client.Get(ctx, "db").Err()).To(Equal(redis.Nil))
176
177 Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
178 Expect(db2.Close()).NotTo(HaveOccurred())
179 })
180
181 It("should client setname", func() {
182 opt := redisOptions()
183 opt.ClientName = "hi"
184 db := redis.NewClient(opt)
185
186 defer func() {
187 Expect(db.Close()).NotTo(HaveOccurred())
188 }()
189
190 Expect(db.Ping(ctx).Err()).NotTo(HaveOccurred())
191 val, err := db.ClientList(ctx).Result()
192 Expect(err).NotTo(HaveOccurred())
193 Expect(val).Should(ContainSubstring("name=hi"))
194 })
195
196 It("should attempt to set client name in HELLO", func() {
197 opt := redisOptions()
198 opt.ClientName = "hi"
199 db := redis.NewClient(opt)
200
201 defer func() {
202 Expect(db.Close()).NotTo(HaveOccurred())
203 }()
204
205
206 name, err := db.ClientGetName(ctx).Result()
207 Expect(err).NotTo(HaveOccurred())
208 Expect(name).Should(Equal("hi"))
209
210
211 conn := db.Conn()
212 hello, err := conn.Hello(ctx, 3, "", "", "hi2").Result()
213 Expect(err).NotTo(HaveOccurred())
214 Expect(hello["proto"]).Should(Equal(int64(3)))
215 name, err = conn.ClientGetName(ctx).Result()
216 Expect(err).NotTo(HaveOccurred())
217 Expect(name).Should(Equal("hi2"))
218 err = conn.Close()
219 Expect(err).NotTo(HaveOccurred())
220 })
221
222 It("should client PROTO 2", func() {
223 opt := redisOptions()
224 opt.Protocol = 2
225 db := redis.NewClient(opt)
226
227 defer func() {
228 Expect(db.Close()).NotTo(HaveOccurred())
229 }()
230
231 val, err := db.Do(ctx, "HELLO").Result()
232 Expect(err).NotTo(HaveOccurred())
233 Expect(val).Should(ContainElements("proto", int64(2)))
234 })
235
236 It("should client PROTO 3", func() {
237 opt := redisOptions()
238 db := redis.NewClient(opt)
239
240 defer func() {
241 Expect(db.Close()).NotTo(HaveOccurred())
242 }()
243
244 val, err := db.Do(ctx, "HELLO").Result()
245 Expect(err).NotTo(HaveOccurred())
246 Expect(val).Should(HaveKeyWithValue("proto", int64(3)))
247 })
248
249 It("processes custom commands", func() {
250 cmd := redis.NewCmd(ctx, "PING")
251 _ = client.Process(ctx, cmd)
252
253
254 Expect(client.Echo(ctx, "hello").Err()).NotTo(HaveOccurred())
255
256 Expect(cmd.Err()).NotTo(HaveOccurred())
257 Expect(cmd.Val()).To(Equal("PONG"))
258 })
259
260 It("should retry command on network error", func() {
261 Expect(client.Close()).NotTo(HaveOccurred())
262
263 client = redis.NewClient(&redis.Options{
264 Addr: redisAddr,
265 MaxRetries: 1,
266 })
267
268
269 cn, err := client.Pool().Get(ctx)
270 Expect(err).NotTo(HaveOccurred())
271
272 cn.SetNetConn(&badConn{})
273 client.Pool().Put(ctx, cn)
274
275 err = client.Ping(ctx).Err()
276 Expect(err).NotTo(HaveOccurred())
277 })
278
279 It("should retry with backoff", func() {
280 clientNoRetry := redis.NewClient(&redis.Options{
281 Addr: ":1234",
282 MaxRetries: -1,
283 })
284 defer clientNoRetry.Close()
285
286 clientRetry := redis.NewClient(&redis.Options{
287 Addr: ":1234",
288 MaxRetries: 5,
289 MaxRetryBackoff: 128 * time.Millisecond,
290 })
291 defer clientRetry.Close()
292
293 startNoRetry := time.Now()
294 err := clientNoRetry.Ping(ctx).Err()
295 Expect(err).To(HaveOccurred())
296 elapseNoRetry := time.Since(startNoRetry)
297
298 startRetry := time.Now()
299 err = clientRetry.Ping(ctx).Err()
300 Expect(err).To(HaveOccurred())
301 elapseRetry := time.Since(startRetry)
302
303 Expect(elapseRetry).To(BeNumerically(">", elapseNoRetry, 10*time.Millisecond))
304 })
305
306 It("should update conn.UsedAt on read/write", func() {
307 cn, err := client.Pool().Get(context.Background())
308 Expect(err).NotTo(HaveOccurred())
309 Expect(cn.UsedAt).NotTo(BeZero())
310
311
312
313
314
315 cn.SetUsedAt(time.Now().Add(-1 * time.Second))
316 createdAt := cn.UsedAt()
317
318 client.Pool().Put(ctx, cn)
319 Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue())
320
321 err = client.Ping(ctx).Err()
322 Expect(err).NotTo(HaveOccurred())
323
324 cn, err = client.Pool().Get(context.Background())
325 Expect(err).NotTo(HaveOccurred())
326 Expect(cn).NotTo(BeNil())
327 Expect(cn.UsedAt().After(createdAt)).To(BeTrue())
328 })
329
330 It("should process command with special chars", func() {
331 set := client.Set(ctx, "key", "hello1\r\nhello2\r\n", 0)
332 Expect(set.Err()).NotTo(HaveOccurred())
333 Expect(set.Val()).To(Equal("OK"))
334
335 get := client.Get(ctx, "key")
336 Expect(get.Err()).NotTo(HaveOccurred())
337 Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
338 })
339
340 It("should handle big vals", func() {
341 bigVal := bytes.Repeat([]byte{'*'}, 2e6)
342
343 err := client.Set(ctx, "key", bigVal, 0).Err()
344 Expect(err).NotTo(HaveOccurred())
345
346
347 Expect(client.Close()).NotTo(HaveOccurred())
348 client = redis.NewClient(redisOptions())
349
350 got, err := client.Get(ctx, "key").Bytes()
351 Expect(err).NotTo(HaveOccurred())
352 Expect(got).To(Equal(bigVal))
353 })
354
355 It("should set and scan time", func() {
356 tm := time.Now()
357 err := client.Set(ctx, "now", tm, 0).Err()
358 Expect(err).NotTo(HaveOccurred())
359
360 var tm2 time.Time
361 err = client.Get(ctx, "now").Scan(&tm2)
362 Expect(err).NotTo(HaveOccurred())
363
364 Expect(tm2).To(BeTemporally("==", tm))
365 })
366
367 It("should set and scan durations", func() {
368 duration := 10 * time.Minute
369 err := client.Set(ctx, "duration", duration, 0).Err()
370 Expect(err).NotTo(HaveOccurred())
371
372 var duration2 time.Duration
373 err = client.Get(ctx, "duration").Scan(&duration2)
374 Expect(err).NotTo(HaveOccurred())
375
376 Expect(duration2).To(Equal(duration))
377 })
378
379 It("should Conn", func() {
380 err := client.Conn().Get(ctx, "this-key-does-not-exist").Err()
381 Expect(err).To(Equal(redis.Nil))
382 })
383
384 It("should set and scan net.IP", func() {
385 ip := net.ParseIP("192.168.1.1")
386 err := client.Set(ctx, "ip", ip, 0).Err()
387 Expect(err).NotTo(HaveOccurred())
388
389 var ip2 net.IP
390 err = client.Get(ctx, "ip").Scan(&ip2)
391 Expect(err).NotTo(HaveOccurred())
392
393 Expect(ip2).To(Equal(ip))
394 })
395 })
396
397 var _ = Describe("Client timeout", func() {
398 var opt *redis.Options
399 var client *redis.Client
400
401 AfterEach(func() {
402 Expect(client.Close()).NotTo(HaveOccurred())
403 })
404
405 testTimeout := func() {
406 It("SETINFO timeouts", func() {
407 conn := client.Conn()
408 err := conn.Ping(ctx).Err()
409 Expect(err).To(HaveOccurred())
410 Expect(err.(net.Error).Timeout()).To(BeTrue())
411 })
412
413 It("Ping timeouts", func() {
414 err := client.Ping(ctx).Err()
415 Expect(err).To(HaveOccurred())
416 Expect(err.(net.Error).Timeout()).To(BeTrue())
417 })
418
419 It("Pipeline timeouts", func() {
420 _, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
421 pipe.Ping(ctx)
422 return nil
423 })
424 Expect(err).To(HaveOccurred())
425 Expect(err.(net.Error).Timeout()).To(BeTrue())
426 })
427
428 It("Subscribe timeouts", func() {
429 if opt.WriteTimeout == 0 {
430 return
431 }
432
433 pubsub := client.Subscribe(ctx)
434 defer pubsub.Close()
435
436 err := pubsub.Subscribe(ctx, "_")
437 Expect(err).To(HaveOccurred())
438 Expect(err.(net.Error).Timeout()).To(BeTrue())
439 })
440
441 It("Tx timeouts", func() {
442 err := client.Watch(ctx, func(tx *redis.Tx) error {
443 return tx.Ping(ctx).Err()
444 })
445 Expect(err).To(HaveOccurred())
446 Expect(err.(net.Error).Timeout()).To(BeTrue())
447 })
448
449 It("Tx Pipeline timeouts", func() {
450 err := client.Watch(ctx, func(tx *redis.Tx) error {
451 _, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
452 pipe.Ping(ctx)
453 return nil
454 })
455 return err
456 })
457 Expect(err).To(HaveOccurred())
458 Expect(err.(net.Error).Timeout()).To(BeTrue())
459 })
460 }
461
462 Context("read timeout", func() {
463 BeforeEach(func() {
464 opt = redisOptions()
465 opt.ReadTimeout = time.Nanosecond
466 opt.WriteTimeout = -1
467 client = redis.NewClient(opt)
468 })
469
470 testTimeout()
471 })
472
473 Context("write timeout", func() {
474 BeforeEach(func() {
475 opt = redisOptions()
476 opt.ReadTimeout = -1
477 opt.WriteTimeout = time.Nanosecond
478 client = redis.NewClient(opt)
479 })
480
481 testTimeout()
482 })
483 })
484
485 var _ = Describe("Client OnConnect", func() {
486 var client *redis.Client
487
488 BeforeEach(func() {
489 opt := redisOptions()
490 opt.DB = 0
491 opt.OnConnect = func(ctx context.Context, cn *redis.Conn) error {
492 return cn.ClientSetName(ctx, "on_connect").Err()
493 }
494
495 client = redis.NewClient(opt)
496 })
497
498 AfterEach(func() {
499 Expect(client.Close()).NotTo(HaveOccurred())
500 })
501
502 It("calls OnConnect", func() {
503 name, err := client.ClientGetName(ctx).Result()
504 Expect(err).NotTo(HaveOccurred())
505 Expect(name).To(Equal("on_connect"))
506 })
507 })
508
509 var _ = Describe("Client context cancellation", func() {
510 var opt *redis.Options
511 var client *redis.Client
512
513 BeforeEach(func() {
514 opt = redisOptions()
515 opt.ReadTimeout = -1
516 opt.WriteTimeout = -1
517 client = redis.NewClient(opt)
518 })
519
520 AfterEach(func() {
521 Expect(client.Close()).NotTo(HaveOccurred())
522 })
523
524 It("Blocking operation cancellation", func() {
525 ctx, cancel := context.WithCancel(ctx)
526 cancel()
527
528 err := client.BLPop(ctx, 1*time.Second, "test").Err()
529 Expect(err).To(HaveOccurred())
530 Expect(err).To(BeIdenticalTo(context.Canceled))
531 })
532 })
533
534 var _ = Describe("Conn", func() {
535 var client *redis.Client
536
537 BeforeEach(func() {
538 client = redis.NewClient(redisOptions())
539 Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
540 })
541
542 AfterEach(func() {
543 err := client.Close()
544 Expect(err).NotTo(HaveOccurred())
545 })
546
547 It("TxPipeline", Label("NonRedisEnterprise"), func() {
548 tx := client.Conn().TxPipeline()
549 tx.SwapDB(ctx, 0, 2)
550 tx.SwapDB(ctx, 1, 0)
551 _, err := tx.Exec(ctx)
552 Expect(err).NotTo(HaveOccurred())
553 })
554 })
555
556 var _ = Describe("Hook", func() {
557 var client *redis.Client
558
559 BeforeEach(func() {
560 client = redis.NewClient(redisOptions())
561 Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
562 })
563
564 AfterEach(func() {
565 err := client.Close()
566 Expect(err).NotTo(HaveOccurred())
567 })
568
569 It("fifo", func() {
570 var res []string
571 client.AddHook(&hook{
572 processHook: func(hook redis.ProcessHook) redis.ProcessHook {
573 return func(ctx context.Context, cmd redis.Cmder) error {
574 res = append(res, "hook-1-process-start")
575 err := hook(ctx, cmd)
576 res = append(res, "hook-1-process-end")
577 return err
578 }
579 },
580 })
581 client.AddHook(&hook{
582 processHook: func(hook redis.ProcessHook) redis.ProcessHook {
583 return func(ctx context.Context, cmd redis.Cmder) error {
584 res = append(res, "hook-2-process-start")
585 err := hook(ctx, cmd)
586 res = append(res, "hook-2-process-end")
587 return err
588 }
589 },
590 })
591
592 err := client.Ping(ctx).Err()
593 Expect(err).NotTo(HaveOccurred())
594
595 Expect(res).To(Equal([]string{
596 "hook-1-process-start",
597 "hook-2-process-start",
598 "hook-2-process-end",
599 "hook-1-process-end",
600 }))
601 })
602
603 It("wrapped error in a hook", func() {
604 client.AddHook(&hook{
605 processHook: func(hook redis.ProcessHook) redis.ProcessHook {
606 return func(ctx context.Context, cmd redis.Cmder) error {
607 if err := hook(ctx, cmd); err != nil {
608 return fmt.Errorf("wrapped error: %w", err)
609 }
610 return nil
611 }
612 },
613 })
614 client.ScriptFlush(ctx)
615
616 script := redis.NewScript(`return 'Script and hook'`)
617
618 cmd := script.Run(ctx, client, nil)
619 Expect(cmd.Err()).NotTo(HaveOccurred())
620 Expect(cmd.Val()).To(Equal("Script and hook"))
621 })
622 })
623
624 var _ = Describe("Hook with MinIdleConns", func() {
625 var client *redis.Client
626
627 BeforeEach(func() {
628 options := redisOptions()
629 options.MinIdleConns = 1
630 client = redis.NewClient(options)
631 Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
632 })
633
634 AfterEach(func() {
635 err := client.Close()
636 Expect(err).NotTo(HaveOccurred())
637 })
638
639 It("fifo", func() {
640 var res []string
641 client.AddHook(&hook{
642 processHook: func(hook redis.ProcessHook) redis.ProcessHook {
643 return func(ctx context.Context, cmd redis.Cmder) error {
644 res = append(res, "hook-1-process-start")
645 err := hook(ctx, cmd)
646 res = append(res, "hook-1-process-end")
647 return err
648 }
649 },
650 })
651 client.AddHook(&hook{
652 processHook: func(hook redis.ProcessHook) redis.ProcessHook {
653 return func(ctx context.Context, cmd redis.Cmder) error {
654 res = append(res, "hook-2-process-start")
655 err := hook(ctx, cmd)
656 res = append(res, "hook-2-process-end")
657 return err
658 }
659 },
660 })
661
662 err := client.Ping(ctx).Err()
663 Expect(err).NotTo(HaveOccurred())
664
665 Expect(res).To(Equal([]string{
666 "hook-1-process-start",
667 "hook-2-process-start",
668 "hook-2-process-end",
669 "hook-1-process-end",
670 }))
671 })
672 })
673
674 var _ = Describe("Dialer connection timeouts", func() {
675 var client *redis.Client
676
677 const dialSimulatedDelay = 1 * time.Second
678
679 BeforeEach(func() {
680 options := redisOptions()
681 options.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
682
683
684 time.Sleep(dialSimulatedDelay)
685 return net.Dial("tcp", options.Addr)
686 }
687 options.MinIdleConns = 1
688 client = redis.NewClient(options)
689 })
690
691 AfterEach(func() {
692 err := client.Close()
693 Expect(err).NotTo(HaveOccurred())
694 })
695
696 It("does not contend on connection dial for concurrent commands", func() {
697 var wg sync.WaitGroup
698
699 const concurrency = 10
700
701 durations := make(chan time.Duration, concurrency)
702 errs := make(chan error, concurrency)
703
704 start := time.Now()
705 wg.Add(concurrency)
706
707 for i := 0; i < concurrency; i++ {
708 go func() {
709 defer wg.Done()
710
711 start := time.Now()
712 err := client.Ping(ctx).Err()
713 durations <- time.Since(start)
714 errs <- err
715 }()
716 }
717
718 wg.Wait()
719 close(durations)
720 close(errs)
721
722
723 for err := range errs {
724 Expect(err).NotTo(HaveOccurred())
725 }
726
727
728 for duration := range durations {
729 Expect(duration).To(BeNumerically("<", 2*dialSimulatedDelay))
730 }
731
732
733
734 Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
735 })
736 })
737
738 var _ = Describe("Credentials Provider Priority", func() {
739 var client *redis.Client
740 var opt *redis.Options
741 var recorder *commandRecorder
742
743 BeforeEach(func() {
744 recorder = newCommandRecorder(10)
745 })
746
747 AfterEach(func() {
748 if client != nil {
749 Expect(client.Close()).NotTo(HaveOccurred())
750 }
751 })
752
753 It("should use streaming provider when available", func() {
754 streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
755 ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
756 providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
757
758 opt = &redis.Options{
759 Username: "field_user",
760 Password: "field_pass",
761 CredentialsProvider: func() (string, string) {
762 username, password := providerCreds.BasicAuth()
763 return username, password
764 },
765 CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
766 username, password := ctxCreds.BasicAuth()
767 return username, password, nil
768 },
769 StreamingCredentialsProvider: &mockStreamingProvider{
770 credentials: streamingCreds,
771 updates: make(chan auth.Credentials, 1),
772 },
773 }
774
775 client = redis.NewClient(opt)
776 client.AddHook(recorder.Hook())
777
778 Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
779 Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
780 })
781
782 It("should use context provider when streaming provider is not available", func() {
783 ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
784 providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
785
786 opt = &redis.Options{
787 Username: "field_user",
788 Password: "field_pass",
789 CredentialsProvider: func() (string, string) {
790 username, password := providerCreds.BasicAuth()
791 return username, password
792 },
793 CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
794 username, password := ctxCreds.BasicAuth()
795 return username, password, nil
796 },
797 }
798
799 client = redis.NewClient(opt)
800 client.AddHook(recorder.Hook())
801
802 Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
803 Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
804 })
805
806 It("should use regular provider when streaming and context providers are not available", func() {
807 providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
808
809 opt = &redis.Options{
810 Username: "field_user",
811 Password: "field_pass",
812 CredentialsProvider: func() (string, string) {
813 username, password := providerCreds.BasicAuth()
814 return username, password
815 },
816 }
817
818 client = redis.NewClient(opt)
819 client.AddHook(recorder.Hook())
820
821 Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
822 Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
823 })
824
825 It("should use username/password fields when no providers are set", func() {
826 opt = &redis.Options{
827 Username: "field_user",
828 Password: "field_pass",
829 }
830
831 client = redis.NewClient(opt)
832 client.AddHook(recorder.Hook())
833
834 Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
835 Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
836 })
837
838 It("should use empty credentials when nothing is set", func() {
839 opt = &redis.Options{}
840
841 client = redis.NewClient(opt)
842 client.AddHook(recorder.Hook())
843
844 Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
845 Expect(recorder.Contains("AUTH")).To(BeFalse())
846 })
847
848 It("should handle credential updates from streaming provider", func() {
849 initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
850 updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
851 updatesChan := make(chan auth.Credentials, 1)
852
853 opt = &redis.Options{
854 StreamingCredentialsProvider: &mockStreamingProvider{
855 credentials: initialCreds,
856 updates: updatesChan,
857 },
858 }
859
860 client = redis.NewClient(opt)
861 client.AddHook(recorder.Hook())
862
863 Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
864 Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
865
866
867 opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
868
869 Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
870 Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
871 close(updatesChan)
872 })
873 })
874
875 type mockStreamingProvider struct {
876 credentials auth.Credentials
877 err error
878 updates chan auth.Credentials
879 }
880
881 func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
882 if m.err != nil {
883 return nil, nil, m.err
884 }
885
886
887 go func() {
888 for creds := range m.updates {
889 m.credentials = creds
890 listener.OnNext(creds)
891 }
892 }()
893
894 return m.credentials, func() (err error) {
895 defer func() {
896 if r := recover(); r != nil {
897
898
899 }
900 }()
901 return
902 }, nil
903 }
904
905 var _ = Describe("Client creation", func() {
906 Context("simple client with nil options", func() {
907 It("panics", func() {
908 Expect(func() {
909 redis.NewClient(nil)
910 }).To(Panic())
911 })
912 })
913 Context("cluster client with nil options", func() {
914 It("panics", func() {
915 Expect(func() {
916 redis.NewClusterClient(nil)
917 }).To(Panic())
918 })
919 })
920 Context("ring client with nil options", func() {
921 It("panics", func() {
922 Expect(func() {
923 redis.NewRing(nil)
924 }).To(Panic())
925 })
926 })
927 Context("universal client with nil options", func() {
928 It("panics", func() {
929 Expect(func() {
930 redis.NewUniversalClient(nil)
931 }).To(Panic())
932 })
933 })
934 Context("failover client with nil options", func() {
935 It("panics", func() {
936 Expect(func() {
937 redis.NewFailoverClient(nil)
938 }).To(Panic())
939 })
940 })
941 Context("failover cluster client with nil options", func() {
942 It("panics", func() {
943 Expect(func() {
944 redis.NewFailoverClusterClient(nil)
945 }).To(Panic())
946 })
947 })
948 Context("sentinel client with nil options", func() {
949 It("panics", func() {
950 Expect(func() {
951 redis.NewSentinelClient(nil)
952 }).To(Panic())
953 })
954 })
955 })
956
View as plain text