1 package redis
2
3 import (
4 "context"
5 "crypto/tls"
6 "errors"
7 "fmt"
8 "net"
9 "strconv"
10 "sync"
11 "sync/atomic"
12 "time"
13
14 "github.com/cespare/xxhash/v2"
15 "github.com/dgryski/go-rendezvous"
16 "github.com/redis/go-redis/v9/auth"
17
18 "github.com/redis/go-redis/v9/internal"
19 "github.com/redis/go-redis/v9/internal/hashtag"
20 "github.com/redis/go-redis/v9/internal/pool"
21 "github.com/redis/go-redis/v9/internal/proto"
22 "github.com/redis/go-redis/v9/internal/rand"
23 )
24
25 var errRingShardsDown = errors.New("redis: all ring shards are down")
26
27
28 var defaultHeartbeatFn = func(ctx context.Context, client *Client) bool {
29 err := client.Ping(ctx).Err()
30 return err == nil || err == pool.ErrPoolTimeout
31 }
32
33
34
// ConsistentHash maps a key to the name of the shard that should serve it.
// Instances are built by RingOptions.NewConsistentHash from the names of the
// shards that are currently up.
type ConsistentHash interface {
	// Get returns the name of the shard chosen for the given key. The ring
	// treats an empty result as "no shard available".
	Get(string) string
}
38
// rendezvousWrapper adapts *rendezvous.Rendezvous to the ConsistentHash
// interface.
type rendezvousWrapper struct {
	*rendezvous.Rendezvous
}
42
43 func (w rendezvousWrapper) Get(key string) string {
44 return w.Lookup(key)
45 }
46
47 func newRendezvous(shards []string) ConsistentHash {
48 return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
49 }
50
51
52
53
54
// RingOptions configures a Ring. Zero values of several fields are replaced
// with defaults by init, and most connection-level fields are copied into
// the Options of every shard client by clientOptions.
type RingOptions struct {
	// Addrs maps a shard name to the address of that shard's server.
	Addrs map[string]string

	// NewClient creates the client for a shard from the given options.
	// Defaults to NewClient.
	NewClient func(opt *Options) *Client

	// ClientName is copied to Options.ClientName of every shard client.
	ClientName string

	// HeartbeatFrequency is the interval between shard liveness checks.
	// Defaults to 500ms.
	HeartbeatFrequency time.Duration

	// HeartbeatFn reports whether the given shard client is up. The default
	// probe pings the shard.
	HeartbeatFn func(ctx context.Context, client *Client) bool

	// NewConsistentHash builds the ConsistentHash that distributes keys
	// among the given shard names; it is called with the names of the
	// shards that are currently up. Defaults to rendezvous hashing with
	// xxhash.
	NewConsistentHash func(shards []string) ConsistentHash

	// The fields from here to DB are copied unchanged into the Options of
	// every shard client.

	Dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error

	Protocol int
	Username string
	Password string

	// CredentialsProvider supplies a username and password on demand.
	CredentialsProvider func() (username string, password string)

	// CredentialsProviderContext is the context-aware, error-returning
	// variant of CredentialsProvider.
	CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)

	// StreamingCredentialsProvider is forwarded to every shard client as
	// Options.StreamingCredentialsProvider.
	StreamingCredentialsProvider auth.StreamingCredentialsProvider
	DB                           int

	// MaxRetries, MinRetryBackoff and MaxRetryBackoff drive the Ring's own
	// retry loop; they are not copied to the shard clients.
	//
	// MaxRetries: -1 disables retries, 0 selects the default of 3.
	// MinRetryBackoff: -1 means no backoff, 0 selects the default of 8ms.
	// MaxRetryBackoff: -1 means no backoff, 0 selects the default of 512ms.
	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	// Timeouts, copied unchanged to every shard client.
	DialTimeout           time.Duration
	ReadTimeout           time.Duration
	WriteTimeout          time.Duration
	ContextTimeoutEnabled bool

	// Connection-pool settings, copied unchanged to every shard client.
	PoolFIFO bool

	PoolSize        int
	PoolTimeout     time.Duration
	MinIdleConns    int
	MaxIdleConns    int
	MaxActiveConns  int
	ConnMaxIdleTime time.Duration
	ConnMaxLifetime time.Duration

	// ReadBufferSize is copied to every shard client. Defaults to
	// proto.DefaultBufferSize when zero.
	ReadBufferSize int

	// WriteBufferSize is copied to every shard client. Defaults to
	// proto.DefaultBufferSize when zero.
	WriteBufferSize int

	TLSConfig *tls.Config
	Limiter   Limiter

	// DisableIndentity is a misspelled twin of DisableIdentity; both are
	// forwarded to the shard clients.
	// NOTE(review): presumably kept only for backward compatibility —
	// confirm before documenting it as deprecated.
	DisableIndentity bool

	// DisableIdentity is copied to Options.DisableIdentity of every shard
	// client.
	DisableIdentity bool
	IdentitySuffix  string
	UnstableResp3   bool
}
158
159 func (opt *RingOptions) init() {
160 if opt.NewClient == nil {
161 opt.NewClient = func(opt *Options) *Client {
162 return NewClient(opt)
163 }
164 }
165
166 if opt.HeartbeatFrequency == 0 {
167 opt.HeartbeatFrequency = 500 * time.Millisecond
168 }
169
170 if opt.HeartbeatFn == nil {
171 opt.HeartbeatFn = defaultHeartbeatFn
172 }
173
174 if opt.NewConsistentHash == nil {
175 opt.NewConsistentHash = newRendezvous
176 }
177
178 switch opt.MaxRetries {
179 case -1:
180 opt.MaxRetries = 0
181 case 0:
182 opt.MaxRetries = 3
183 }
184 switch opt.MinRetryBackoff {
185 case -1:
186 opt.MinRetryBackoff = 0
187 case 0:
188 opt.MinRetryBackoff = 8 * time.Millisecond
189 }
190 switch opt.MaxRetryBackoff {
191 case -1:
192 opt.MaxRetryBackoff = 0
193 case 0:
194 opt.MaxRetryBackoff = 512 * time.Millisecond
195 }
196
197 if opt.ReadBufferSize == 0 {
198 opt.ReadBufferSize = proto.DefaultBufferSize
199 }
200 if opt.WriteBufferSize == 0 {
201 opt.WriteBufferSize = proto.DefaultBufferSize
202 }
203 }
204
205 func (opt *RingOptions) clientOptions() *Options {
206 return &Options{
207 ClientName: opt.ClientName,
208 Dialer: opt.Dialer,
209 OnConnect: opt.OnConnect,
210
211 Protocol: opt.Protocol,
212 Username: opt.Username,
213 Password: opt.Password,
214 CredentialsProvider: opt.CredentialsProvider,
215 CredentialsProviderContext: opt.CredentialsProviderContext,
216 StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
217 DB: opt.DB,
218
219 MaxRetries: -1,
220
221 DialTimeout: opt.DialTimeout,
222 ReadTimeout: opt.ReadTimeout,
223 WriteTimeout: opt.WriteTimeout,
224 ContextTimeoutEnabled: opt.ContextTimeoutEnabled,
225
226 PoolFIFO: opt.PoolFIFO,
227 PoolSize: opt.PoolSize,
228 PoolTimeout: opt.PoolTimeout,
229 MinIdleConns: opt.MinIdleConns,
230 MaxIdleConns: opt.MaxIdleConns,
231 MaxActiveConns: opt.MaxActiveConns,
232 ConnMaxIdleTime: opt.ConnMaxIdleTime,
233 ConnMaxLifetime: opt.ConnMaxLifetime,
234 ReadBufferSize: opt.ReadBufferSize,
235 WriteBufferSize: opt.WriteBufferSize,
236
237 TLSConfig: opt.TLSConfig,
238 Limiter: opt.Limiter,
239
240 DisableIdentity: opt.DisableIdentity,
241 DisableIndentity: opt.DisableIndentity,
242
243 IdentitySuffix: opt.IdentitySuffix,
244 UnstableResp3: opt.UnstableResp3,
245 }
246 }
247
248
249
// ringShard is a single member of the ring: a client plus its health state.
type ringShard struct {
	// Client is the client connected to this shard.
	Client *Client
	// down counts failed heartbeat votes since the last successful one. It
	// is accessed atomically; the shard is considered down once the count
	// reaches the threshold in IsDown.
	down int32
	// addr is the server address the shard was created for.
	addr string
}
255
256 func newRingShard(opt *RingOptions, addr string) *ringShard {
257 clopt := opt.clientOptions()
258 clopt.Addr = addr
259
260 return &ringShard{
261 Client: opt.NewClient(clopt),
262 addr: addr,
263 }
264 }
265
266 func (shard *ringShard) String() string {
267 var state string
268 if shard.IsUp() {
269 state = "up"
270 } else {
271 state = "down"
272 }
273 return fmt.Sprintf("%s is %s", shard.Client, state)
274 }
275
276 func (shard *ringShard) IsDown() bool {
277 const threshold = 3
278 return atomic.LoadInt32(&shard.down) >= threshold
279 }
280
281 func (shard *ringShard) IsUp() bool {
282 return !shard.IsDown()
283 }
284
285
286 func (shard *ringShard) Vote(up bool) bool {
287 if up {
288 changed := shard.IsDown()
289 atomic.StoreInt32(&shard.down, 0)
290 return changed
291 }
292
293 if shard.IsDown() {
294 return false
295 }
296
297 atomic.AddInt32(&shard.down, 1)
298 return shard.IsDown()
299 }
300
301
302
// ringSharding owns the set of shards of a ring and the consistent hash
// built over the shards that are currently up.
type ringSharding struct {
	opt *RingOptions

	// mu guards shards, closed, hash and numShard; OnNewNode also holds it
	// while appending to onNewNode.
	mu     sync.RWMutex
	shards *ringShards
	closed bool
	// hash is rebuilt by rebalanceLocked from the names of the live shards.
	hash ConsistentHash
	// numShard is the number of live shards hash was built from.
	numShard int
	// onNewNode holds callbacks invoked with the client of every newly
	// created shard.
	onNewNode []func(rdb *Client)

	// setAddrsMu serializes SetAddrs calls, which release mu while the new
	// shard set is being built.
	setAddrsMu sync.Mutex
}
317
// ringShards is one immutable-by-convention snapshot of the ring's shard
// set; SetAddrs replaces the whole snapshot rather than editing it.
type ringShards struct {
	// m maps shard name to shard.
	m map[string]*ringShard
	// list holds the values of m, in no particular order.
	list []*ringShard
}
322
323 func newRingSharding(opt *RingOptions) *ringSharding {
324 c := &ringSharding{
325 opt: opt,
326 }
327 c.SetAddrs(opt.Addrs)
328
329 return c
330 }
331
332 func (c *ringSharding) OnNewNode(fn func(rdb *Client)) {
333 c.mu.Lock()
334 c.onNewNode = append(c.onNewNode, fn)
335 c.mu.Unlock()
336 }
337
338
339
340
341 func (c *ringSharding) SetAddrs(addrs map[string]string) {
342 c.setAddrsMu.Lock()
343 defer c.setAddrsMu.Unlock()
344
345 cleanup := func(shards map[string]*ringShard) {
346 for addr, shard := range shards {
347 if err := shard.Client.Close(); err != nil {
348 internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
349 }
350 }
351 }
352
353 c.mu.RLock()
354 if c.closed {
355 c.mu.RUnlock()
356 return
357 }
358 existing := c.shards
359 c.mu.RUnlock()
360
361 shards, created, unused := c.newRingShards(addrs, existing)
362
363 c.mu.Lock()
364 if c.closed {
365 cleanup(created)
366 c.mu.Unlock()
367 return
368 }
369 c.shards = shards
370 c.rebalanceLocked()
371 c.mu.Unlock()
372
373 cleanup(unused)
374 }
375
376 func (c *ringSharding) newRingShards(
377 addrs map[string]string, existing *ringShards,
378 ) (shards *ringShards, created, unused map[string]*ringShard) {
379 shards = &ringShards{m: make(map[string]*ringShard, len(addrs))}
380 created = make(map[string]*ringShard)
381 unused = make(map[string]*ringShard)
382
383 if existing != nil {
384 for _, shard := range existing.list {
385 unused[shard.addr] = shard
386 }
387 }
388
389 for name, addr := range addrs {
390 if shard, ok := unused[addr]; ok {
391 shards.m[name] = shard
392 delete(unused, addr)
393 } else {
394 shard := newRingShard(c.opt, addr)
395 shards.m[name] = shard
396 created[addr] = shard
397
398 for _, fn := range c.onNewNode {
399 fn(shard.Client)
400 }
401 }
402 }
403
404 for _, shard := range shards.m {
405 shards.list = append(shards.list, shard)
406 }
407
408 return
409 }
410
411
412
413 func (c *ringSharding) List() []*ringShard {
414 c.mu.RLock()
415 defer c.mu.RUnlock()
416
417 if c.closed {
418 return nil
419 }
420 return c.shards.list
421 }
422
423 func (c *ringSharding) Hash(key string) string {
424 key = hashtag.Key(key)
425
426 var hash string
427
428 c.mu.RLock()
429 defer c.mu.RUnlock()
430
431 if c.numShard > 0 {
432 hash = c.hash.Get(key)
433 }
434
435 return hash
436 }
437
438 func (c *ringSharding) GetByKey(key string) (*ringShard, error) {
439 key = hashtag.Key(key)
440
441 c.mu.RLock()
442 defer c.mu.RUnlock()
443
444 if c.closed {
445 return nil, pool.ErrClosed
446 }
447
448 if c.numShard == 0 {
449 return nil, errRingShardsDown
450 }
451
452 shardName := c.hash.Get(key)
453 if shardName == "" {
454 return nil, errRingShardsDown
455 }
456 return c.shards.m[shardName], nil
457 }
458
459 func (c *ringSharding) GetByName(shardName string) (*ringShard, error) {
460 if shardName == "" {
461 return c.Random()
462 }
463
464 c.mu.RLock()
465 defer c.mu.RUnlock()
466
467 shard, ok := c.shards.m[shardName]
468 if !ok {
469 return nil, errors.New("redis: the shard is not in the ring")
470 }
471
472 return shard, nil
473 }
474
475 func (c *ringSharding) Random() (*ringShard, error) {
476 return c.GetByKey(strconv.Itoa(rand.Int()))
477 }
478
479
480 func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) {
481 ticker := time.NewTicker(frequency)
482 defer ticker.Stop()
483
484 for {
485 select {
486 case <-ticker.C:
487 var rebalance bool
488
489
490 for _, shard := range c.List() {
491 isUp := c.opt.HeartbeatFn(ctx, shard.Client)
492 if shard.Vote(isUp) {
493 internal.Logger.Printf(ctx, "ring shard state changed: %s", shard)
494 rebalance = true
495 }
496 }
497
498 if rebalance {
499 c.mu.Lock()
500 c.rebalanceLocked()
501 c.mu.Unlock()
502 }
503 case <-ctx.Done():
504 return
505 }
506 }
507 }
508
509
510
511 func (c *ringSharding) rebalanceLocked() {
512 if c.closed {
513 return
514 }
515 if c.shards == nil {
516 return
517 }
518
519 liveShards := make([]string, 0, len(c.shards.m))
520
521 for name, shard := range c.shards.m {
522 if shard.IsUp() {
523 liveShards = append(liveShards, name)
524 }
525 }
526
527 c.hash = c.opt.NewConsistentHash(liveShards)
528 c.numShard = len(liveShards)
529 }
530
531 func (c *ringSharding) Len() int {
532 c.mu.RLock()
533 defer c.mu.RUnlock()
534
535 return c.numShard
536 }
537
538 func (c *ringSharding) Close() error {
539 c.mu.Lock()
540 defer c.mu.Unlock()
541
542 if c.closed {
543 return nil
544 }
545 c.closed = true
546
547 var firstErr error
548
549 for _, shard := range c.shards.list {
550 if err := shard.Client.Close(); err != nil && firstErr == nil {
551 firstErr = err
552 }
553 }
554
555 c.hash = nil
556 c.shards = nil
557 c.numShard = 0
558
559 return firstErr
560 }
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
// Ring is a client that spreads keys over several shards using consistent
// hashing.
//
// A background heartbeat started by NewRing probes every shard. Shards that
// are voted down are left out of the hash until a later probe succeeds, so
// their keys are served by the remaining shards in the meantime.
type Ring struct {
	// cmdable is bound to Ring.Process by NewRing.
	cmdable
	hooksMixin

	opt      *RingOptions
	sharding *ringSharding
	// cmdsInfoCache is created over Ring.cmdsInfo by NewRing.
	cmdsInfoCache *cmdsInfoCache
	// heartbeatCancelFn cancels the context of the heartbeat goroutine.
	heartbeatCancelFn context.CancelFunc
}
587
588 func NewRing(opt *RingOptions) *Ring {
589 if opt == nil {
590 panic("redis: NewRing nil options")
591 }
592 opt.init()
593
594 hbCtx, hbCancel := context.WithCancel(context.Background())
595
596 ring := Ring{
597 opt: opt,
598 sharding: newRingSharding(opt),
599 heartbeatCancelFn: hbCancel,
600 }
601
602 ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
603 ring.cmdable = ring.Process
604
605 ring.initHooks(hooks{
606 process: ring.process,
607 pipeline: func(ctx context.Context, cmds []Cmder) error {
608 return ring.generalProcessPipeline(ctx, cmds, false)
609 },
610 txPipeline: func(ctx context.Context, cmds []Cmder) error {
611 return ring.generalProcessPipeline(ctx, cmds, true)
612 },
613 })
614
615 go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency)
616
617 return &ring
618 }
619
// SetAddrs replaces the ring's shards with the ones described by addrs
// (shard name -> address). Shards whose address is unchanged are reused and
// shards that are no longer referenced are closed.
func (c *Ring) SetAddrs(addrs map[string]string) {
	c.sharding.SetAddrs(addrs)
}
623
// Process runs cmd through the hook chain, stores the resulting error on
// cmd and returns it.
func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
	err := c.processHook(ctx, cmd)
	cmd.SetErr(err)
	return err
}
629
630
// Options returns the options the ring was created with. The returned
// pointer is the ring's own copy, not a clone.
func (c *Ring) Options() *RingOptions {
	return c.opt
}
634
635 func (c *Ring) retryBackoff(attempt int) time.Duration {
636 return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
637 }
638
639
640 func (c *Ring) PoolStats() *PoolStats {
641
642 shards := c.sharding.List()
643 var acc PoolStats
644 for _, shard := range shards {
645 s := shard.Client.connPool.Stats()
646 acc.Hits += s.Hits
647 acc.Misses += s.Misses
648 acc.Timeouts += s.Timeouts
649 acc.TotalConns += s.TotalConns
650 acc.IdleConns += s.IdleConns
651 }
652 return &acc
653 }
654
655
// Len returns the number of shards that are currently part of the
// consistent hash, i.e. the shards that were up at the last rebalance.
func (c *Ring) Len() int {
	return c.sharding.Len()
}
659
660
661 func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
662 if len(channels) == 0 {
663 panic("at least one channel is required")
664 }
665
666 shard, err := c.sharding.GetByKey(channels[0])
667 if err != nil {
668
669 panic(err)
670 }
671 return shard.Client.Subscribe(ctx, channels...)
672 }
673
674
675 func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
676 if len(channels) == 0 {
677 panic("at least one channel is required")
678 }
679
680 shard, err := c.sharding.GetByKey(channels[0])
681 if err != nil {
682
683 panic(err)
684 }
685 return shard.Client.PSubscribe(ctx, channels...)
686 }
687
688
689 func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub {
690 if len(channels) == 0 {
691 panic("at least one channel is required")
692 }
693 shard, err := c.sharding.GetByKey(channels[0])
694 if err != nil {
695
696 panic(err)
697 }
698 return shard.Client.SSubscribe(ctx, channels...)
699 }
700
// OnNewNode registers fn to be called with the client of every shard the
// ring creates from now on.
func (c *Ring) OnNewNode(fn func(rdb *Client)) {
	c.sharding.OnNewNode(fn)
}
704
705
706
707 func (c *Ring) ForEachShard(
708 ctx context.Context,
709 fn func(ctx context.Context, client *Client) error,
710 ) error {
711
712 shards := c.sharding.List()
713 var wg sync.WaitGroup
714 errCh := make(chan error, 1)
715 for _, shard := range shards {
716 if shard.IsDown() {
717 continue
718 }
719
720 wg.Add(1)
721 go func(shard *ringShard) {
722 defer wg.Done()
723 err := fn(ctx, shard.Client)
724 if err != nil {
725 select {
726 case errCh <- err:
727 default:
728 }
729 }
730 }(shard)
731 }
732 wg.Wait()
733
734 select {
735 case err := <-errCh:
736 return err
737 default:
738 return nil
739 }
740 }
741
742 func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
743
744 shards := c.sharding.List()
745 var firstErr error
746 for _, shard := range shards {
747 cmdsInfo, err := shard.Client.Command(ctx).Result()
748 if err == nil {
749 return cmdsInfo, nil
750 }
751 if firstErr == nil {
752 firstErr = err
753 }
754 }
755 if firstErr == nil {
756 return nil, errRingShardsDown
757 }
758 return nil, firstErr
759 }
760
761 func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
762 pos := cmdFirstKeyPos(cmd)
763 if pos == 0 {
764 return c.sharding.Random()
765 }
766 firstKey := cmd.stringArg(pos)
767 return c.sharding.GetByKey(firstKey)
768 }
769
770 func (c *Ring) process(ctx context.Context, cmd Cmder) error {
771 var lastErr error
772 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
773 if attempt > 0 {
774 if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
775 return err
776 }
777 }
778
779 shard, err := c.cmdShard(cmd)
780 if err != nil {
781 return err
782 }
783
784 lastErr = shard.Client.Process(ctx, cmd)
785 if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
786 return lastErr
787 }
788 }
789 return lastErr
790 }
791
792 func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
793 return c.Pipeline().Pipelined(ctx, fn)
794 }
795
796 func (c *Ring) Pipeline() Pipeliner {
797 pipe := Pipeline{
798 exec: pipelineExecer(c.processPipelineHook),
799 }
800 pipe.init()
801 return &pipe
802 }
803
804 func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
805 return c.TxPipeline().Pipelined(ctx, fn)
806 }
807
808 func (c *Ring) TxPipeline() Pipeliner {
809 pipe := Pipeline{
810 exec: func(ctx context.Context, cmds []Cmder) error {
811 cmds = wrapMultiExec(ctx, cmds)
812 return c.processTxPipelineHook(ctx, cmds)
813 },
814 }
815 pipe.init()
816 return &pipe
817 }
818
819 func (c *Ring) generalProcessPipeline(
820 ctx context.Context, cmds []Cmder, tx bool,
821 ) error {
822 if tx {
823
824 cmds = cmds[1 : len(cmds)-1]
825 }
826
827 cmdsMap := make(map[string][]Cmder)
828
829 for _, cmd := range cmds {
830 hash := cmd.stringArg(cmdFirstKeyPos(cmd))
831 if hash != "" {
832 hash = c.sharding.Hash(hash)
833 }
834 cmdsMap[hash] = append(cmdsMap[hash], cmd)
835 }
836
837 var wg sync.WaitGroup
838 errs := make(chan error, len(cmdsMap))
839
840 for hash, cmds := range cmdsMap {
841 wg.Add(1)
842 go func(hash string, cmds []Cmder) {
843 defer wg.Done()
844
845
846 shard, err := c.sharding.GetByName(hash)
847 if err != nil {
848 setCmdsErr(cmds, err)
849 return
850 }
851
852 hook := shard.Client.processPipelineHook
853 if tx {
854 cmds = wrapMultiExec(ctx, cmds)
855 hook = shard.Client.processTxPipelineHook
856 }
857
858 if err = hook(ctx, cmds); err != nil {
859 errs <- err
860 }
861 }(hash, cmds)
862 }
863
864 wg.Wait()
865 close(errs)
866
867 if err := <-errs; err != nil {
868 return err
869 }
870 return cmdsFirstErr(cmds)
871 }
872
873 func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
874 if len(keys) == 0 {
875 return fmt.Errorf("redis: Watch requires at least one key")
876 }
877
878 var shards []*ringShard
879
880 for _, key := range keys {
881 if key != "" {
882 shard, err := c.sharding.GetByKey(key)
883 if err != nil {
884 return err
885 }
886
887 shards = append(shards, shard)
888 }
889 }
890
891 if len(shards) == 0 {
892 return fmt.Errorf("redis: Watch requires at least one shard")
893 }
894
895 if len(shards) > 1 {
896 for _, shard := range shards[1:] {
897 if shard.Client != shards[0].Client {
898 err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
899 return err
900 }
901 }
902 }
903
904 return shards[0].Client.Watch(ctx, fn, keys...)
905 }
906
907
908
909
910
911 func (c *Ring) Close() error {
912 c.heartbeatCancelFn()
913
914 return c.sharding.Close()
915 }
916
917
918
919 func (c *Ring) GetShardClients() []*Client {
920 shards := c.sharding.List()
921 clients := make([]*Client, 0, len(shards))
922 for _, shard := range shards {
923 if shard.IsUp() {
924 clients = append(clients, shard.Client)
925 }
926 }
927 return clients
928 }
929
930
931
932 func (c *Ring) GetShardClientForKey(key string) (*Client, error) {
933 shard, err := c.sharding.GetByKey(key)
934 if err != nil {
935 return nil, err
936 }
937 return shard.Client, nil
938 }
939
View as plain text