...

Source file src/github.com/redis/go-redis/v9/ring.go

Documentation: github.com/redis/go-redis/v9

package redis

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cespare/xxhash/v2"
	"github.com/dgryski/go-rendezvous" //nolint
	"github.com/redis/go-redis/v9/auth"

	"github.com/redis/go-redis/v9/internal"
	"github.com/redis/go-redis/v9/internal/hashtag"
	"github.com/redis/go-redis/v9/internal/pool"
	"github.com/redis/go-redis/v9/internal/proto"
	"github.com/redis/go-redis/v9/internal/rand"
)

var errRingShardsDown = errors.New("redis: all ring shards are down")

// defaultHeartbeatFn is the default function used to check shard liveness.
var defaultHeartbeatFn = func(ctx context.Context, client *Client) bool {
	err := client.Ping(ctx).Err()
	return err == nil || err == pool.ErrPoolTimeout
}

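// A custom liveness check can be supplied via RingOptions.HeartbeatFn. A
// minimal sketch (the tighter per-ping timeout is illustrative, not a
// recommendation):
//
//	opt := &RingOptions{
//		Addrs: map[string]string{"shard1": ":6379"},
//		HeartbeatFn: func(ctx context.Context, client *Client) bool {
//			ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
//			defer cancel()
//			return client.Ping(ctx).Err() == nil
//		},
//	}
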
//------------------------------------------------------------------------------

type ConsistentHash interface {
	Get(string) string
}

type rendezvousWrapper struct {
	*rendezvous.Rendezvous
}

func (w rendezvousWrapper) Get(key string) string {
	return w.Lookup(key)
}

func newRendezvous(shards []string) ConsistentHash {
	return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}

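// Any implementation of ConsistentHash can be plugged in through
// RingOptions.NewConsistentHash. A deliberately naive sketch (modulo hashing
// is NOT consistent hashing and remaps most keys when the shard set changes;
// it is shown only to illustrate the interface):
//
//	type modHash struct{ shards []string }
//
//	func (h modHash) Get(key string) string {
//		if len(h.shards) == 0 {
//			return ""
//		}
//		return h.shards[xxhash.Sum64String(key)%uint64(len(h.shards))]
//	}
//
//	opt := &RingOptions{
//		NewConsistentHash: func(shards []string) ConsistentHash {
//			return modHash{shards: shards}
//		},
//	}
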
//------------------------------------------------------------------------------

// RingOptions are used to configure a ring client and should be
// passed to NewRing.
type RingOptions struct {
	// Addrs is a map of name => host:port addresses of ring shards.
	Addrs map[string]string

	// NewClient creates a shard client with provided options.
	NewClient func(opt *Options) *Client

	// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
	ClientName string

	// HeartbeatFrequency is the frequency of executing HeartbeatFn to check
	// shard availability. A shard is considered down after 3 consecutive
	// failed checks.
	HeartbeatFrequency time.Duration

	// HeartbeatFn is the function used to check shard liveness.
	// If not set, it defaults to defaultHeartbeatFn.
	HeartbeatFn func(ctx context.Context, client *Client) bool

	// NewConsistentHash returns a consistent hash that is used
	// to distribute keys across the shards.
	//
	// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
	// for consistent hashing algorithmic tradeoffs.
	NewConsistentHash func(shards []string) ConsistentHash

	// The following options are copied from the Options struct.

	Dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error

	Protocol int
	Username string
	Password string
	// CredentialsProvider allows the username and password to be updated
	// before reconnecting. It should return the current username and password.
	CredentialsProvider func() (username string, password string)

	// CredentialsProviderContext is a context-aware variant of
	// CredentialsProvider, kept as a separate field to maintain API
	// compatibility; the two may be merged in the future. If both are set,
	// CredentialsProviderContext takes precedence and CredentialsProvider
	// is ignored.
	CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)

	// StreamingCredentialsProvider is used to retrieve the credentials
	// for the connection from an external source. Those credentials may change
	// during the connection lifetime. This is useful for managed identity
	// scenarios where the credentials are retrieved from an external source.
	//
	// Currently, this is a placeholder for the future implementation.
	StreamingCredentialsProvider auth.StreamingCredentialsProvider
	DB                           int

	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	DialTimeout           time.Duration
	ReadTimeout           time.Duration
	WriteTimeout          time.Duration
	ContextTimeoutEnabled bool

	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
	PoolFIFO bool

	PoolSize        int
	PoolTimeout     time.Duration
	MinIdleConns    int
	MaxIdleConns    int
	MaxActiveConns  int
	ConnMaxIdleTime time.Duration
	ConnMaxLifetime time.Duration

	// ReadBufferSize is the size of the bufio.Reader buffer for each connection.
	// Larger buffers can improve performance for commands that return large responses.
	// Smaller buffers can improve memory usage for larger pools.
	//
	// default: 32KiB (32768 bytes)
	ReadBufferSize int

	// WriteBufferSize is the size of the bufio.Writer buffer for each connection.
	// Larger buffers can improve performance for large pipelines and commands with many arguments.
	// Smaller buffers can improve memory usage for larger pools.
	//
	// default: 32KiB (32768 bytes)
	WriteBufferSize int

	TLSConfig *tls.Config
	Limiter   Limiter

	// DisableIndentity disables set-lib on connect.
	//
	// default: false
	//
	// Deprecated: Use DisableIdentity instead.
	DisableIndentity bool

	// DisableIdentity is used to disable the CLIENT SETINFO command on connect.
	//
	// default: false
	DisableIdentity bool
	IdentitySuffix  string
	UnstableResp3   bool
}

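// A typical Ring is configured with a name => address map; fields that are
// left unset fall back to the defaults applied in init. For example (the
// shard names and addresses are illustrative):
//
//	rdb := NewRing(&RingOptions{
//		Addrs: map[string]string{
//			"shard1": "localhost:7000",
//			"shard2": "localhost:7001",
//		},
//	})
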
func (opt *RingOptions) init() {
	if opt.NewClient == nil {
		opt.NewClient = func(opt *Options) *Client {
			return NewClient(opt)
		}
	}

	if opt.HeartbeatFrequency == 0 {
		opt.HeartbeatFrequency = 500 * time.Millisecond
	}

	if opt.HeartbeatFn == nil {
		opt.HeartbeatFn = defaultHeartbeatFn
	}

	if opt.NewConsistentHash == nil {
		opt.NewConsistentHash = newRendezvous
	}

	switch opt.MaxRetries {
	case -1:
		opt.MaxRetries = 0
	case 0:
		opt.MaxRetries = 3
	}
	switch opt.MinRetryBackoff {
	case -1:
		opt.MinRetryBackoff = 0
	case 0:
		opt.MinRetryBackoff = 8 * time.Millisecond
	}
	switch opt.MaxRetryBackoff {
	case -1:
		opt.MaxRetryBackoff = 0
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}

	if opt.ReadBufferSize == 0 {
		opt.ReadBufferSize = proto.DefaultBufferSize
	}
	if opt.WriteBufferSize == 0 {
		opt.WriteBufferSize = proto.DefaultBufferSize
	}
}

func (opt *RingOptions) clientOptions() *Options {
	return &Options{
		ClientName: opt.ClientName,
		Dialer:     opt.Dialer,
		OnConnect:  opt.OnConnect,

		Protocol:                     opt.Protocol,
		Username:                     opt.Username,
		Password:                     opt.Password,
		CredentialsProvider:          opt.CredentialsProvider,
		CredentialsProviderContext:   opt.CredentialsProviderContext,
		StreamingCredentialsProvider: opt.StreamingCredentialsProvider,
		DB:                           opt.DB,

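		// -1 disables retries at the shard-client level;
		// (*Ring).process below implements its own retry loop.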
		MaxRetries: -1,

		DialTimeout:           opt.DialTimeout,
		ReadTimeout:           opt.ReadTimeout,
		WriteTimeout:          opt.WriteTimeout,
		ContextTimeoutEnabled: opt.ContextTimeoutEnabled,

		PoolFIFO:        opt.PoolFIFO,
		PoolSize:        opt.PoolSize,
		PoolTimeout:     opt.PoolTimeout,
		MinIdleConns:    opt.MinIdleConns,
		MaxIdleConns:    opt.MaxIdleConns,
		MaxActiveConns:  opt.MaxActiveConns,
		ConnMaxIdleTime: opt.ConnMaxIdleTime,
		ConnMaxLifetime: opt.ConnMaxLifetime,
		ReadBufferSize:  opt.ReadBufferSize,
		WriteBufferSize: opt.WriteBufferSize,

		TLSConfig: opt.TLSConfig,
		Limiter:   opt.Limiter,

		DisableIdentity:  opt.DisableIdentity,
		DisableIndentity: opt.DisableIndentity,

		IdentitySuffix: opt.IdentitySuffix,
		UnstableResp3:  opt.UnstableResp3,
	}
}

//------------------------------------------------------------------------------

type ringShard struct {
	Client *Client
	down   int32
	addr   string
}

func newRingShard(opt *RingOptions, addr string) *ringShard {
	clopt := opt.clientOptions()
	clopt.Addr = addr

	return &ringShard{
		Client: opt.NewClient(clopt),
		addr:   addr,
	}
}

func (shard *ringShard) String() string {
	var state string
	if shard.IsUp() {
		state = "up"
	} else {
		state = "down"
	}
	return fmt.Sprintf("%s is %s", shard.Client, state)
}

func (shard *ringShard) IsDown() bool {
	const threshold = 3
	return atomic.LoadInt32(&shard.down) >= threshold
}

func (shard *ringShard) IsUp() bool {
	return !shard.IsDown()
}

// Vote votes to set shard state and returns true if state was changed.
func (shard *ringShard) Vote(up bool) bool {
	if up {
		changed := shard.IsDown()
		atomic.StoreInt32(&shard.down, 0)
		return changed
	}

	if shard.IsDown() {
		return false
	}

	atomic.AddInt32(&shard.down, 1)
	return shard.IsDown()
}

//------------------------------------------------------------------------------

type ringSharding struct {
	opt *RingOptions

	mu        sync.RWMutex
	shards    *ringShards
	closed    bool
	hash      ConsistentHash
	numShard  int
	onNewNode []func(rdb *Client)

	// ensures exclusive access to SetAddrs so there is no need
	// to hold mu for the duration of potentially long shard creation
	setAddrsMu sync.Mutex
}

type ringShards struct {
	m    map[string]*ringShard
	list []*ringShard
}

func newRingSharding(opt *RingOptions) *ringSharding {
	c := &ringSharding{
		opt: opt,
	}
	c.SetAddrs(opt.Addrs)

	return c
}

func (c *ringSharding) OnNewNode(fn func(rdb *Client)) {
	c.mu.Lock()
	c.onNewNode = append(c.onNewNode, fn)
	c.mu.Unlock()
}

// SetAddrs replaces the shards in use, allowing the number of shards to be
// increased or decreased. It reuses shards that existed before and closes
// the ones that are no longer needed.
func (c *ringSharding) SetAddrs(addrs map[string]string) {
	c.setAddrsMu.Lock()
	defer c.setAddrsMu.Unlock()

	cleanup := func(shards map[string]*ringShard) {
		for addr, shard := range shards {
			if err := shard.Client.Close(); err != nil {
				internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err)
			}
		}
	}

	c.mu.RLock()
	if c.closed {
		c.mu.RUnlock()
		return
	}
	existing := c.shards
	c.mu.RUnlock()

	shards, created, unused := c.newRingShards(addrs, existing)

	c.mu.Lock()
	if c.closed {
		cleanup(created)
		c.mu.Unlock()
		return
	}
	c.shards = shards
	c.rebalanceLocked()
	c.mu.Unlock()

	cleanup(unused)
}

func (c *ringSharding) newRingShards(
	addrs map[string]string, existing *ringShards,
) (shards *ringShards, created, unused map[string]*ringShard) {
	shards = &ringShards{m: make(map[string]*ringShard, len(addrs))}
	created = make(map[string]*ringShard) // indexed by addr
	unused = make(map[string]*ringShard)  // indexed by addr

	if existing != nil {
		for _, shard := range existing.list {
			unused[shard.addr] = shard
		}
	}

	for name, addr := range addrs {
		if shard, ok := unused[addr]; ok {
			shards.m[name] = shard
			delete(unused, addr)
		} else {
			shard := newRingShard(c.opt, addr)
			shards.m[name] = shard
			created[addr] = shard

			for _, fn := range c.onNewNode {
				fn(shard.Client)
			}
		}
	}

	for _, shard := range shards.m {
		shards.list = append(shards.list, shard)
	}

	return
}

// Warning: exposing `c.shards.list` externally may cause data races, so
// keep it internal or implement a deep copy if it must be exposed.
func (c *ringSharding) List() []*ringShard {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.closed {
		return nil
	}
	return c.shards.list
}

func (c *ringSharding) Hash(key string) string {
	key = hashtag.Key(key)

	var hash string

	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.numShard > 0 {
		hash = c.hash.Get(key)
	}

	return hash
}

func (c *ringSharding) GetByKey(key string) (*ringShard, error) {
	key = hashtag.Key(key)

	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.closed {
		return nil, pool.ErrClosed
	}

	if c.numShard == 0 {
		return nil, errRingShardsDown
	}

	shardName := c.hash.Get(key)
	if shardName == "" {
		return nil, errRingShardsDown
	}
	return c.shards.m[shardName], nil
}

func (c *ringSharding) GetByName(shardName string) (*ringShard, error) {
	if shardName == "" {
		return c.Random()
	}

	c.mu.RLock()
	defer c.mu.RUnlock()

	shard, ok := c.shards.m[shardName]
	if !ok {
		return nil, errors.New("redis: the shard is not in the ring")
	}

	return shard, nil
}

func (c *ringSharding) Random() (*ringShard, error) {
	return c.GetByKey(strconv.Itoa(rand.Int()))
}

// Heartbeat monitors the state of each shard in the ring.
func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) {
	ticker := time.NewTicker(frequency)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			var rebalance bool

			// Note: `c.List()` returns a shadow copy of `[]*ringShard`.
			for _, shard := range c.List() {
				isUp := c.opt.HeartbeatFn(ctx, shard.Client)
				if shard.Vote(isUp) {
					internal.Logger.Printf(ctx, "ring shard state changed: %s", shard)
					rebalance = true
				}
			}

			if rebalance {
				c.mu.Lock()
				c.rebalanceLocked()
				c.mu.Unlock()
			}
		case <-ctx.Done():
			return
		}
	}
}

// rebalanceLocked removes dead shards from the Ring.
// Requires c.mu locked.
func (c *ringSharding) rebalanceLocked() {
	if c.closed {
		return
	}
	if c.shards == nil {
		return
	}

	liveShards := make([]string, 0, len(c.shards.m))

	for name, shard := range c.shards.m {
		if shard.IsUp() {
			liveShards = append(liveShards, name)
		}
	}

	c.hash = c.opt.NewConsistentHash(liveShards)
	c.numShard = len(liveShards)
}

func (c *ringSharding) Len() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return c.numShard
}

func (c *ringSharding) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}
	c.closed = true

	var firstErr error

	for _, shard := range c.shards.list {
		if err := shard.Client.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	c.hash = nil
	c.shards = nil
	c.numShard = 0

	return firstErr
}

//------------------------------------------------------------------------------

// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
//
// Ring monitors the state of each shard and removes dead shards from
// the ring. When a shard comes online it is added back to the ring. This
// gives you maximum availability and partition tolerance, but no
// consistency between different shards or even clients. Each client
// uses shards that are available to the client and does not do any
// coordination when shard state is changed.
//
// Ring should be used when you need multiple Redis servers for caching
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
type Ring struct {
	cmdable
	hooksMixin

	opt               *RingOptions
	sharding          *ringSharding
	cmdsInfoCache     *cmdsInfoCache
	heartbeatCancelFn context.CancelFunc
}

func NewRing(opt *RingOptions) *Ring {
	if opt == nil {
		panic("redis: NewRing nil options")
	}
	opt.init()

	hbCtx, hbCancel := context.WithCancel(context.Background())

	ring := Ring{
		opt:               opt,
		sharding:          newRingSharding(opt),
		heartbeatCancelFn: hbCancel,
	}

	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
	ring.cmdable = ring.Process

	ring.initHooks(hooks{
		process: ring.process,
		pipeline: func(ctx context.Context, cmds []Cmder) error {
			return ring.generalProcessPipeline(ctx, cmds, false)
		},
		txPipeline: func(ctx context.Context, cmds []Cmder) error {
			return ring.generalProcessPipeline(ctx, cmds, true)
		},
	})

	go ring.sharding.Heartbeat(hbCtx, opt.HeartbeatFrequency)

	return &ring
}

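// Commands issued on a Ring are routed to a shard by their first key. A
// short usage sketch (key and value are illustrative):
//
//	ctx := context.Background()
//	if err := rdb.Set(ctx, "key", "value", 0).Err(); err != nil {
//		// handle error
//	}
//	val, err := rdb.Get(ctx, "key").Result()
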
func (c *Ring) SetAddrs(addrs map[string]string) {
	c.sharding.SetAddrs(addrs)
}

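// SetAddrs can resize the ring at runtime: shards whose addresses are still
// present are reused, the rest are closed. For example, growing from two
// shards to three (addresses are illustrative):
//
//	rdb.SetAddrs(map[string]string{
//		"shard1": "localhost:7000",
//		"shard2": "localhost:7001",
//		"shard3": "localhost:7002",
//	})
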
func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
	err := c.processHook(ctx, cmd)
	cmd.SetErr(err)
	return err
}

// Options returns read-only Options that were used to create the client.
func (c *Ring) Options() *RingOptions {
	return c.opt
}

func (c *Ring) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

// PoolStats returns accumulated connection pool stats.
func (c *Ring) PoolStats() *PoolStats {
	// Note: `c.List()` returns a shadow copy of `[]*ringShard`.
	shards := c.sharding.List()
	var acc PoolStats
	for _, shard := range shards {
		s := shard.Client.connPool.Stats()
		acc.Hits += s.Hits
		acc.Misses += s.Misses
		acc.Timeouts += s.Timeouts
		acc.TotalConns += s.TotalConns
		acc.IdleConns += s.IdleConns
	}
	return &acc
}

// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
	return c.sharding.Len()
}

// Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.sharding.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.Subscribe(ctx, channels...)
}

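// The returned PubSub is bound to the shard that owns the first channel. A
// minimal consumption sketch (the channel name is illustrative):
//
//	pubsub := rdb.Subscribe(ctx, "mychannel")
//	defer pubsub.Close()
//	for msg := range pubsub.Channel() {
//		fmt.Println(msg.Channel, msg.Payload)
//	}
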
// PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.sharding.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.PSubscribe(ctx, channels...)
}

// SSubscribe subscribes the client to the specified shard channels.
func (c *Ring) SSubscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}
	shard, err := c.sharding.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.SSubscribe(ctx, channels...)
}

func (c *Ring) OnNewNode(fn func(rdb *Client)) {
	c.sharding.OnNewNode(fn)
}

// ForEachShard concurrently calls fn on each live shard in the ring.
// It returns the first error, if any.
func (c *Ring) ForEachShard(
	ctx context.Context,
	fn func(ctx context.Context, client *Client) error,
) error {
	// Note: `c.List()` returns a shadow copy of `[]*ringShard`.
	shards := c.sharding.List()
	var wg sync.WaitGroup
	errCh := make(chan error, 1)
	for _, shard := range shards {
		if shard.IsDown() {
			continue
		}

		wg.Add(1)
		go func(shard *ringShard) {
			defer wg.Done()
			err := fn(ctx, shard.Client)
			if err != nil {
				select {
				case errCh <- err:
				default:
				}
			}
		}(shard)
	}
	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}

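// ForEachShard is convenient for fan-out operations. For example, pinging
// every live shard:
//
//	err := rdb.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
//		return shard.Ping(ctx).Err()
//	})
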
func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
	// Note: `c.List()` returns a shadow copy of `[]*ringShard`.
	shards := c.sharding.List()
	var firstErr error
	for _, shard := range shards {
		cmdsInfo, err := shard.Client.Command(ctx).Result()
		if err == nil {
			return cmdsInfo, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	if firstErr == nil {
		return nil, errRingShardsDown
	}
	return nil, firstErr
}

func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
	pos := cmdFirstKeyPos(cmd)
	if pos == 0 {
		return c.sharding.Random()
	}
	firstKey := cmd.stringArg(pos)
	return c.sharding.GetByKey(firstKey)
}

func (c *Ring) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		shard, err := c.cmdShard(cmd)
		if err != nil {
			return err
		}

		lastErr = shard.Client.Process(ctx, cmd)
		if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
			return lastErr
		}
	}
	return lastErr
}

func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Ring) Pipeline() Pipeliner {
	pipe := Pipeline{
		exec: pipelineExecer(c.processPipelineHook),
	}
	pipe.init()
	return &pipe
}

func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

func (c *Ring) TxPipeline() Pipeliner {
	pipe := Pipeline{
		exec: func(ctx context.Context, cmds []Cmder) error {
			cmds = wrapMultiExec(ctx, cmds)
			return c.processTxPipelineHook(ctx, cmds)
		},
	}
	pipe.init()
	return &pipe
}

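// Queued commands are grouped by shard and executed with one round trip per
// shard. A short sketch (the keys are illustrative):
//
//	_, err := rdb.Pipelined(ctx, func(pipe Pipeliner) error {
//		pipe.Set(ctx, "a", "1", 0)
//		pipe.Incr(ctx, "counter")
//		return nil
//	})
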
func (c *Ring) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, tx bool,
) error {
	if tx {
		// Trim the MULTI .. EXEC wrapper.
		cmds = cmds[1 : len(cmds)-1]
	}

	cmdsMap := make(map[string][]Cmder)

	for _, cmd := range cmds {
		hash := cmd.stringArg(cmdFirstKeyPos(cmd))
		if hash != "" {
			hash = c.sharding.Hash(hash)
		}
		cmdsMap[hash] = append(cmdsMap[hash], cmd)
	}

	var wg sync.WaitGroup
	errs := make(chan error, len(cmdsMap))

	for hash, cmds := range cmdsMap {
		wg.Add(1)
		go func(hash string, cmds []Cmder) {
			defer wg.Done()

			// TODO: retry?
			shard, err := c.sharding.GetByName(hash)
			if err != nil {
				setCmdsErr(cmds, err)
				return
			}

			hook := shard.Client.processPipelineHook
			if tx {
				cmds = wrapMultiExec(ctx, cmds)
				hook = shard.Client.processTxPipelineHook
			}

			if err = hook(ctx, cmds); err != nil {
				errs <- err
			}
		}(hash, cmds)
	}

	wg.Wait()
	close(errs)

	if err := <-errs; err != nil {
		return err
	}
	return cmdsFirstErr(cmds)
}

func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	if len(keys) == 0 {
		return fmt.Errorf("redis: Watch requires at least one key")
	}

	var shards []*ringShard

	for _, key := range keys {
		if key != "" {
			shard, err := c.sharding.GetByKey(key)
			if err != nil {
				return err
			}

			shards = append(shards, shard)
		}
	}

	if len(shards) == 0 {
		return fmt.Errorf("redis: Watch requires at least one shard")
	}

	if len(shards) > 1 {
		for _, shard := range shards[1:] {
			if shard.Client != shards[0].Client {
				err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
				return err
			}
		}
	}

	return shards[0].Client.Watch(ctx, fn, keys...)
}

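// Watch runs an optimistic-locking transaction on the shard that owns the
// keys. A hedged sketch that increments a counter transactionally (the key
// name is illustrative):
//
//	err := rdb.Watch(ctx, func(tx *Tx) error {
//		n, err := tx.Get(ctx, "counter").Int()
//		if err != nil && err != Nil {
//			return err
//		}
//		_, err = tx.TxPipelined(ctx, func(pipe Pipeliner) error {
//			pipe.Set(ctx, "counter", n+1, 0)
//			return nil
//		})
//		return err
//	}, "counter")
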
// Close closes the ring client, releasing any open resources.
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() error {
	c.heartbeatCancelFn()

	return c.sharding.Close()
}

// GetShardClients returns a list of all shard clients in the ring.
// This can be used to create dedicated connections (e.g., PubSub) for each shard.
func (c *Ring) GetShardClients() []*Client {
	shards := c.sharding.List()
	clients := make([]*Client, 0, len(shards))
	for _, shard := range shards {
		if shard.IsUp() {
			clients = append(clients, shard.Client)
		}
	}
	return clients
}

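// For example, opening a dedicated PubSub connection per live shard (the
// channel name is illustrative; close each subscription when done):
//
//	subs := make([]*PubSub, 0)
//	for _, shard := range rdb.GetShardClients() {
//		subs = append(subs, shard.Subscribe(ctx, "events"))
//	}
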
// GetShardClientForKey returns the shard client that would handle the given key.
// This can be used to determine which shard a particular key/channel would be routed to.
func (c *Ring) GetShardClientForKey(key string) (*Client, error) {
	shard, err := c.sharding.GetByKey(key)
	if err != nil {
		return nil, err
	}
	return shard.Client, nil
}