...

Source file src/github.com/redis/go-redis/v9/redis.go

Documentation: github.com/redis/go-redis/v9

     1  package redis
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/redis/go-redis/v9/auth"
    13  	"github.com/redis/go-redis/v9/internal"
    14  	"github.com/redis/go-redis/v9/internal/hscan"
    15  	"github.com/redis/go-redis/v9/internal/pool"
    16  	"github.com/redis/go-redis/v9/internal/proto"
    17  )
    18  
// Scanner re-exports the internal/hscan.Scanner interface so that it can be
// referenced by code outside this module.
type Scanner = hscan.Scanner

// Nil is the reply returned by Redis when a key does not exist.
// It is the same value as proto.Nil, so errors can be compared against it
// directly (err == redis.Nil).
const Nil = proto.Nil

// SetLogger replaces the package-wide logger (internal.Logger) used for the
// library's internal log messages.
//
// NOTE(review): the assignment is not synchronized; presumably it is meant to
// be called once during start-up, before clients are in use.
func SetLogger(logger internal.Logging) {
	internal.Logger = logger
}
    29  
    30  //------------------------------------------------------------------------------
    31  
// Hook intercepts dialing, single-command processing and pipeline processing.
// Each method receives the next function in the chain and returns a wrapper
// around it. Returning nil leaves that stage of the chain unchanged.
// ProcessPipelineHook is used to wrap both regular pipelines and MULTI/EXEC
// (transactional) pipelines.
type Hook interface {
	DialHook(next DialHook) DialHook
	ProcessHook(next ProcessHook) ProcessHook
	ProcessPipelineHook(next ProcessPipelineHook) ProcessPipelineHook
}

type (
	// DialHook opens a network connection to addr.
	DialHook            func(ctx context.Context, network, addr string) (net.Conn, error)
	// ProcessHook executes a single command.
	ProcessHook         func(ctx context.Context, cmd Cmder) error
	// ProcessPipelineHook executes a batch of commands as one pipeline.
	ProcessPipelineHook func(ctx context.Context, cmds []Cmder) error
)
    43  
// hooksMixin stores the registered hooks and the hook chains composed from
// them. It is embedded in baseClient, which gives Client and Conn AddHook and
// the dial/process helpers defined on it.
type hooksMixin struct {
	// hooksMu serializes rebuilding current (chain) with reads of it in
	// dialHook and clone. It is a pointer, so plain struct copies of the
	// mixin share the same mutex until initHooks/clone install a new one.
	hooksMu *sync.RWMutex

	// slice holds the registered hooks in AddHook order.
	slice []Hook
	// initial holds the innermost implementations supplied by the client.
	initial hooks
	// current holds initial wrapped by every hook in slice.
	current hooks
}

// initHooks installs a fresh mutex, records the innermost implementations
// and builds the composed chain from whatever hooks are already in slice.
func (hs *hooksMixin) initHooks(hooks hooks) {
	hs.hooksMu = new(sync.RWMutex)
	hs.initial = hooks
	hs.chain()
}
    57  
// hooks groups one function per interception point. It is used both for the
// innermost implementations (hooksMixin.initial) and for the composed chains
// (hooksMixin.current).
type hooks struct {
	dial       DialHook
	process    ProcessHook
	pipeline   ProcessPipelineHook
	txPipeline ProcessPipelineHook
}
    64  
    65  func (h *hooks) setDefaults() {
    66  	if h.dial == nil {
    67  		h.dial = func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil }
    68  	}
    69  	if h.process == nil {
    70  		h.process = func(ctx context.Context, cmd Cmder) error { return nil }
    71  	}
    72  	if h.pipeline == nil {
    73  		h.pipeline = func(ctx context.Context, cmds []Cmder) error { return nil }
    74  	}
    75  	if h.txPipeline == nil {
    76  		h.txPipeline = func(ctx context.Context, cmds []Cmder) error { return nil }
    77  	}
    78  }
    79  
    80  // AddHook is to add a hook to the queue.
    81  // Hook is a function executed during network connection, command execution, and pipeline,
    82  // it is a first-in-first-out stack queue (FIFO).
    83  // You need to execute the next hook in each hook, unless you want to terminate the execution of the command.
    84  // For example, you added hook-1, hook-2:
    85  //
    86  //	client.AddHook(hook-1, hook-2)
    87  //
    88  // hook-1:
    89  //
    90  //	func (Hook1) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
    91  //	 	return func(ctx context.Context, cmd Cmder) error {
    92  //		 	print("hook-1 start")
    93  //		 	next(ctx, cmd)
    94  //		 	print("hook-1 end")
    95  //		 	return nil
    96  //	 	}
    97  //	}
    98  //
    99  // hook-2:
   100  //
   101  //	func (Hook2) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
   102  //		return func(ctx context.Context, cmd redis.Cmder) error {
   103  //			print("hook-2 start")
   104  //			next(ctx, cmd)
   105  //			print("hook-2 end")
   106  //			return nil
   107  //		}
   108  //	}
   109  //
   110  // The execution sequence is:
   111  //
   112  //	hook-1 start -> hook-2 start -> exec redis cmd -> hook-2 end -> hook-1 end
   113  //
   114  // Please note: "next(ctx, cmd)" is very important, it will call the next hook,
   115  // if "next(ctx, cmd)" is not executed, the redis command will not be executed.
   116  func (hs *hooksMixin) AddHook(hook Hook) {
   117  	hs.slice = append(hs.slice, hook)
   118  	hs.chain()
   119  }
   120  
   121  func (hs *hooksMixin) chain() {
   122  	hs.initial.setDefaults()
   123  
   124  	hs.hooksMu.Lock()
   125  	defer hs.hooksMu.Unlock()
   126  
   127  	hs.current.dial = hs.initial.dial
   128  	hs.current.process = hs.initial.process
   129  	hs.current.pipeline = hs.initial.pipeline
   130  	hs.current.txPipeline = hs.initial.txPipeline
   131  
   132  	for i := len(hs.slice) - 1; i >= 0; i-- {
   133  		if wrapped := hs.slice[i].DialHook(hs.current.dial); wrapped != nil {
   134  			hs.current.dial = wrapped
   135  		}
   136  		if wrapped := hs.slice[i].ProcessHook(hs.current.process); wrapped != nil {
   137  			hs.current.process = wrapped
   138  		}
   139  		if wrapped := hs.slice[i].ProcessPipelineHook(hs.current.pipeline); wrapped != nil {
   140  			hs.current.pipeline = wrapped
   141  		}
   142  		if wrapped := hs.slice[i].ProcessPipelineHook(hs.current.txPipeline); wrapped != nil {
   143  			hs.current.txPipeline = wrapped
   144  		}
   145  	}
   146  }
   147  
   148  func (hs *hooksMixin) clone() hooksMixin {
   149  	hs.hooksMu.Lock()
   150  	defer hs.hooksMu.Unlock()
   151  
   152  	clone := *hs
   153  	l := len(clone.slice)
   154  	clone.slice = clone.slice[:l:l]
   155  	clone.hooksMu = new(sync.RWMutex)
   156  	return clone
   157  }
   158  
   159  func (hs *hooksMixin) withProcessHook(ctx context.Context, cmd Cmder, hook ProcessHook) error {
   160  	for i := len(hs.slice) - 1; i >= 0; i-- {
   161  		if wrapped := hs.slice[i].ProcessHook(hook); wrapped != nil {
   162  			hook = wrapped
   163  		}
   164  	}
   165  	return hook(ctx, cmd)
   166  }
   167  
   168  func (hs *hooksMixin) withProcessPipelineHook(
   169  	ctx context.Context, cmds []Cmder, hook ProcessPipelineHook,
   170  ) error {
   171  	for i := len(hs.slice) - 1; i >= 0; i-- {
   172  		if wrapped := hs.slice[i].ProcessPipelineHook(hook); wrapped != nil {
   173  			hook = wrapped
   174  		}
   175  	}
   176  	return hook(ctx, cmds)
   177  }
   178  
   179  func (hs *hooksMixin) dialHook(ctx context.Context, network, addr string) (net.Conn, error) {
   180  	// Access to hs.current is guarded by a read-only lock since it may be mutated by AddHook(...)
   181  	// while this dialer is concurrently accessed by the background connection pool population
   182  	// routine when MinIdleConns > 0.
   183  	hs.hooksMu.RLock()
   184  	current := hs.current
   185  	hs.hooksMu.RUnlock()
   186  
   187  	return current.dial(ctx, network, addr)
   188  }
   189  
// processHook runs cmd through the composed process chain (hs.current).
// Unlike dialHook, it reads hs.current without taking hooksMu.
func (hs *hooksMixin) processHook(ctx context.Context, cmd Cmder) error {
	return hs.current.process(ctx, cmd)
}

// processPipelineHook runs cmds through the composed pipeline chain.
func (hs *hooksMixin) processPipelineHook(ctx context.Context, cmds []Cmder) error {
	return hs.current.pipeline(ctx, cmds)
}

// processTxPipelineHook runs cmds through the composed transactional
// (MULTI/EXEC) pipeline chain.
func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) error {
	return hs.current.txPipeline(ctx, cmds)
}
   201  
   202  //------------------------------------------------------------------------------
   203  
// baseClient holds the state and command plumbing shared by Client and Conn:
// the options, the connection pool and the hook chain.
type baseClient struct {
	opt      *Options
	connPool pool.Pooler
	hooksMixin

	onClose func() error // hook called when client is closed
}
   211  
   212  func (c *baseClient) clone() *baseClient {
   213  	clone := *c
   214  	return &clone
   215  }
   216  
   217  func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
   218  	opt := c.opt.clone()
   219  	opt.ReadTimeout = timeout
   220  	opt.WriteTimeout = timeout
   221  
   222  	clone := c.clone()
   223  	clone.opt = opt
   224  
   225  	return clone
   226  }
   227  
   228  func (c *baseClient) String() string {
   229  	return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
   230  }
   231  
   232  func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
   233  	cn, err := c.connPool.NewConn(ctx)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	err = c.initConn(ctx, cn)
   239  	if err != nil {
   240  		_ = c.connPool.CloseConn(cn)
   241  		return nil, err
   242  	}
   243  
   244  	return cn, nil
   245  }
   246  
   247  func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
   248  	if c.opt.Limiter != nil {
   249  		err := c.opt.Limiter.Allow()
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  	}
   254  
   255  	cn, err := c._getConn(ctx)
   256  	if err != nil {
   257  		if c.opt.Limiter != nil {
   258  			c.opt.Limiter.ReportResult(err)
   259  		}
   260  		return nil, err
   261  	}
   262  
   263  	return cn, nil
   264  }
   265  
   266  func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
   267  	cn, err := c.connPool.Get(ctx)
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	if cn.Inited {
   273  		return cn, nil
   274  	}
   275  
   276  	if err := c.initConn(ctx, cn); err != nil {
   277  		c.connPool.Remove(ctx, cn, err)
   278  		if err := errors.Unwrap(err); err != nil {
   279  			return nil, err
   280  		}
   281  		return nil, err
   282  	}
   283  
   284  	return cn, nil
   285  }
   286  
// newReAuthCredentialsListener builds the listener that initConn subscribes to
// the StreamingCredentialsProvider for poolCn. It combines reAuthConnection
// (presumably invoked on each credentials update) with onAuthenticationErr
// (presumably invoked when an update or re-authentication fails) — confirm
// against auth.NewReAuthCredentialsListener.
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
	return auth.NewReAuthCredentialsListener(
		c.reAuthConnection(poolCn),
		c.onAuthenticationErr(poolCn),
	)
}
   293  
   294  func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
   295  	return func(credentials auth.Credentials) error {
   296  		var err error
   297  		username, password := credentials.BasicAuth()
   298  		ctx := context.Background()
   299  		connPool := pool.NewSingleConnPool(c.connPool, poolCn)
   300  		// hooksMixin are intentionally empty here
   301  		cn := newConn(c.opt, connPool, nil)
   302  
   303  		if username != "" {
   304  			err = cn.AuthACL(ctx, username, password).Err()
   305  		} else {
   306  			err = cn.Auth(ctx, password).Err()
   307  		}
   308  		return err
   309  	}
   310  }
   311  func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
   312  	return func(err error) {
   313  		if err != nil {
   314  			if isBadConn(err, false, c.opt.Addr) {
   315  				// Close the connection to force a reconnection.
   316  				err := c.connPool.CloseConn(poolCn)
   317  				if err != nil {
   318  					internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err)
   319  					// try to close the network connection directly
   320  					// so that no resource is leaked
   321  					err := poolCn.Close()
   322  					if err != nil {
   323  						internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err)
   324  					}
   325  				}
   326  			}
   327  			internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err)
   328  		}
   329  	}
   330  }
   331  
   332  func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
   333  	onClose := c.onClose
   334  	return func() error {
   335  		var firstErr error
   336  		err := newOnClose()
   337  		// Even if we have an error we would like to execute the onClose hook
   338  		// if it exists. We will return the first error that occurred.
   339  		// This is to keep error handling consistent with the rest of the code.
   340  		if err != nil {
   341  			firstErr = err
   342  		}
   343  		if onClose != nil {
   344  			err = onClose()
   345  			if err != nil && firstErr == nil {
   346  				firstErr = err
   347  			}
   348  		}
   349  		return firstErr
   350  	}
   351  }
   352  
// initConn runs the connection handshake on cn, unless it has already been
// initialized. The steps are:
//  1. Resolve credentials from the first configured source, in this order:
//     StreamingCredentialsProvider, CredentialsProviderContext,
//     CredentialsProvider, then the static Username/Password.
//  2. Send HELLO. If HELLO fails with a Redis error and a password is set,
//     fall back to AUTH.
//  3. Send SELECT / READONLY / CLIENT SETNAME (as configured) in one pipeline.
//  4. Send CLIENT SETINFO unless identity reporting is disabled.
//  5. Run opt.OnConnect, if set.
//
// Commands go through a Conn that wraps cn in a single-connection pool and
// uses a clone of this client's hooks.
//
// NOTE(review): cn.Inited is set before the handshake runs, so it stays true
// even if the handshake fails. The callers in this file (newConn, _getConn)
// discard the connection on error.
//
// NOTE(review): with a StreamingCredentialsProvider, every call wraps
// c.onClose once more, so the close chain grows with each connection
// initialized. The field is also written without synchronization while other
// goroutines may be initializing connections. Confirm this is intended.
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
	if cn.Inited {
		return nil
	}

	var err error
	cn.Inited = true
	connPool := pool.NewSingleConnPool(c.connPool, cn)
	conn := newConn(c.opt, connPool, &c.hooksMixin)

	username, password := "", ""
	if c.opt.StreamingCredentialsProvider != nil {
		// Subscribing yields the current credentials plus an unsubscribe func,
		// which runs both when the client closes and when cn closes.
		credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
			Subscribe(c.newReAuthCredentialsListener(cn))
		if err != nil {
			return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
		}
		c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
		cn.SetOnClose(unsubscribeFromCredentialsProvider)
		username, password = credentials.BasicAuth()
	} else if c.opt.CredentialsProviderContext != nil {
		username, password, err = c.opt.CredentialsProviderContext(ctx)
		if err != nil {
			return fmt.Errorf("failed to get credentials from context provider: %w", err)
		}
	} else if c.opt.CredentialsProvider != nil {
		username, password = c.opt.CredentialsProvider()
	} else if c.opt.Username != "" || c.opt.Password != "" {
		username, password = c.opt.Username, c.opt.Password
	}

	// for redis-server versions that do not support the HELLO command,
	// RESP2 will continue to be used.
	if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
		// Authentication successful with HELLO command
	} else if !isRedisError(err) {
		// When the server responds with the RESP protocol and the result is not a normal
		// execution result of the HELLO command, we consider it to be an indication that
		// the server does not support the HELLO command.
		// The server may be a redis-server that does not support the HELLO command,
		// or it could be DragonflyDB or a third-party redis-proxy. They all respond
		// with different error string results for unsupported commands, making it
		// difficult to rely on error strings to determine all results.
		return err
	} else if password != "" {
		// Try legacy AUTH command if HELLO failed
		if username != "" {
			err = conn.AuthACL(ctx, username, password).Err()
		} else {
			err = conn.Auth(ctx, password).Err()
		}
		if err != nil {
			return fmt.Errorf("failed to authenticate: %w", err)
		}
	}
	// A HELLO Redis error with no password falls through: initialization
	// continues without authenticating.

	_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
		if c.opt.DB > 0 {
			pipe.Select(ctx, c.opt.DB)
		}

		if c.opt.readOnly {
			pipe.ReadOnly(ctx)
		}

		if c.opt.ClientName != "" {
			pipe.ClientSetName(ctx, c.opt.ClientName)
		}

		return nil
	})
	if err != nil {
		return fmt.Errorf("failed to initialize connection options: %w", err)
	}

	// Either flag disables identity reporting; DisableIndentity presumably is a
	// legacy (misspelled) alias of DisableIdentity.
	if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
		libName := ""
		libVer := Version()
		if c.opt.IdentitySuffix != "" {
			libName = c.opt.IdentitySuffix
		}
		p := conn.Pipeline()
		p.ClientSetInfo(ctx, WithLibraryName(libName))
		p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
		// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
		// out of order responses later on. Redis error replies (e.g. a server
		// without CLIENT SETINFO) are ignored.
		if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
			return err
		}
	}

	if c.opt.OnConnect != nil {
		return c.opt.OnConnect(ctx, conn)
	}

	return nil
}
   450  
   451  func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
   452  	if c.opt.Limiter != nil {
   453  		c.opt.Limiter.ReportResult(err)
   454  	}
   455  
   456  	if isBadConn(err, false, c.opt.Addr) {
   457  		c.connPool.Remove(ctx, cn, err)
   458  	} else {
   459  		c.connPool.Put(ctx, cn)
   460  	}
   461  }
   462  
   463  func (c *baseClient) withConn(
   464  	ctx context.Context, fn func(context.Context, *pool.Conn) error,
   465  ) error {
   466  	cn, err := c.getConn(ctx)
   467  	if err != nil {
   468  		return err
   469  	}
   470  
   471  	var fnErr error
   472  	defer func() {
   473  		c.releaseConn(ctx, cn, fnErr)
   474  	}()
   475  
   476  	fnErr = fn(ctx, cn)
   477  
   478  	return fnErr
   479  }
   480  
// dial opens a network connection with the configured opt.Dialer. Client.init
// and newConn install it as the innermost layer of the dial hook chain.
func (c *baseClient) dial(ctx context.Context, network, addr string) (net.Conn, error) {
	return c.opt.Dialer(ctx, network, addr)
}
   484  
   485  func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
   486  	var lastErr error
   487  	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
   488  		attempt := attempt
   489  
   490  		retry, err := c._process(ctx, cmd, attempt)
   491  		if err == nil || !retry {
   492  			return err
   493  		}
   494  
   495  		lastErr = err
   496  	}
   497  	return lastErr
   498  }
   499  
   500  func (c *baseClient) assertUnstableCommand(cmd Cmder) bool {
   501  	switch cmd.(type) {
   502  	case *AggregateCmd, *FTInfoCmd, *FTSpellCheckCmd, *FTSearchCmd, *FTSynDumpCmd:
   503  		if c.opt.UnstableResp3 {
   504  			return true
   505  		} else {
   506  			panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 .  See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.")
   507  		}
   508  	default:
   509  		return false
   510  	}
   511  }
   512  
   513  func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) {
   514  	if attempt > 0 {
   515  		if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
   516  			return false, err
   517  		}
   518  	}
   519  
   520  	retryTimeout := uint32(0)
   521  	if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
   522  		if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
   523  			return writeCmd(wr, cmd)
   524  		}); err != nil {
   525  			atomic.StoreUint32(&retryTimeout, 1)
   526  			return err
   527  		}
   528  		readReplyFunc := cmd.readReply
   529  		// Apply unstable RESP3 search module.
   530  		if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
   531  			readReplyFunc = cmd.readRawReply
   532  		}
   533  		if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), readReplyFunc); err != nil {
   534  			if cmd.readTimeout() == nil {
   535  				atomic.StoreUint32(&retryTimeout, 1)
   536  			} else {
   537  				atomic.StoreUint32(&retryTimeout, 0)
   538  			}
   539  			return err
   540  		}
   541  
   542  		return nil
   543  	}); err != nil {
   544  		retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
   545  		return retry, err
   546  	}
   547  
   548  	return false, nil
   549  }
   550  
// retryBackoff returns how long to wait before retry number attempt. The value
// comes from internal.RetryBackoff, bounded by opt.MinRetryBackoff and
// opt.MaxRetryBackoff.
func (c *baseClient) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
   554  
   555  func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
   556  	if timeout := cmd.readTimeout(); timeout != nil {
   557  		t := *timeout
   558  		if t == 0 {
   559  			return 0
   560  		}
   561  		return t + 10*time.Second
   562  	}
   563  	return c.opt.ReadTimeout
   564  }
   565  
   566  // context returns the context for the current connection.
   567  // If the context timeout is enabled, it returns the original context.
   568  // Otherwise, it returns a new background context.
   569  func (c *baseClient) context(ctx context.Context) context.Context {
   570  	if c.opt.ContextTimeoutEnabled {
   571  		return ctx
   572  	}
   573  	return context.Background()
   574  }
   575  
   576  // Close closes the client, releasing any open resources.
   577  //
   578  // It is rare to Close a Client, as the Client is meant to be
   579  // long-lived and shared between many goroutines.
   580  func (c *baseClient) Close() error {
   581  	var firstErr error
   582  	if c.onClose != nil {
   583  		if err := c.onClose(); err != nil {
   584  			firstErr = err
   585  		}
   586  	}
   587  	if err := c.connPool.Close(); err != nil && firstErr == nil {
   588  		firstErr = err
   589  	}
   590  	return firstErr
   591  }
   592  
// getAddr returns the configured server address (opt.Addr).
func (c *baseClient) getAddr() string {
	return c.opt.Addr
}
   596  
   597  func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
   598  	if err := c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds); err != nil {
   599  		return err
   600  	}
   601  	return cmdsFirstErr(cmds)
   602  }
   603  
   604  func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
   605  	if err := c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds); err != nil {
   606  		return err
   607  	}
   608  	return cmdsFirstErr(cmds)
   609  }
   610  
   611  type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
   612  
   613  func (c *baseClient) generalProcessPipeline(
   614  	ctx context.Context, cmds []Cmder, p pipelineProcessor,
   615  ) error {
   616  	var lastErr error
   617  	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
   618  		if attempt > 0 {
   619  			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
   620  				setCmdsErr(cmds, err)
   621  				return err
   622  			}
   623  		}
   624  
   625  		// Enable retries by default to retry dial errors returned by withConn.
   626  		canRetry := true
   627  		lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
   628  			var err error
   629  			canRetry, err = p(ctx, cn, cmds)
   630  			return err
   631  		})
   632  		if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
   633  			setCmdsErr(cmds, lastErr)
   634  			return lastErr
   635  		}
   636  	}
   637  	return lastErr
   638  }
   639  
   640  func (c *baseClient) pipelineProcessCmds(
   641  	ctx context.Context, cn *pool.Conn, cmds []Cmder,
   642  ) (bool, error) {
   643  	if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
   644  		return writeCmds(wr, cmds)
   645  	}); err != nil {
   646  		setCmdsErr(cmds, err)
   647  		return true, err
   648  	}
   649  
   650  	if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
   651  		return pipelineReadCmds(rd, cmds)
   652  	}); err != nil {
   653  		return true, err
   654  	}
   655  
   656  	return false, nil
   657  }
   658  
   659  func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
   660  	for i, cmd := range cmds {
   661  		err := cmd.readReply(rd)
   662  		cmd.SetErr(err)
   663  		if err != nil && !isRedisError(err) {
   664  			setCmdsErr(cmds[i+1:], err)
   665  			return err
   666  		}
   667  	}
   668  	// Retry errors like "LOADING redis is loading the dataset in memory".
   669  	return cmds[0].Err()
   670  }
   671  
   672  func (c *baseClient) txPipelineProcessCmds(
   673  	ctx context.Context, cn *pool.Conn, cmds []Cmder,
   674  ) (bool, error) {
   675  	if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
   676  		return writeCmds(wr, cmds)
   677  	}); err != nil {
   678  		setCmdsErr(cmds, err)
   679  		return true, err
   680  	}
   681  
   682  	if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
   683  		statusCmd := cmds[0].(*StatusCmd)
   684  		// Trim multi and exec.
   685  		trimmedCmds := cmds[1 : len(cmds)-1]
   686  
   687  		if err := txPipelineReadQueued(rd, statusCmd, trimmedCmds); err != nil {
   688  			setCmdsErr(cmds, err)
   689  			return err
   690  		}
   691  
   692  		return pipelineReadCmds(rd, trimmedCmds)
   693  	}); err != nil {
   694  		return false, err
   695  	}
   696  
   697  	return false, nil
   698  }
   699  
   700  func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
   701  	// Parse +OK.
   702  	if err := statusCmd.readReply(rd); err != nil {
   703  		return err
   704  	}
   705  
   706  	// Parse +QUEUED.
   707  	for _, cmd := range cmds {
   708  		if err := statusCmd.readReply(rd); err != nil {
   709  			cmd.SetErr(err)
   710  			if !isRedisError(err) {
   711  				return err
   712  			}
   713  		}
   714  	}
   715  
   716  	// Parse number of replies.
   717  	line, err := rd.ReadLine()
   718  	if err != nil {
   719  		if err == Nil {
   720  			err = TxFailedErr
   721  		}
   722  		return err
   723  	}
   724  
   725  	if line[0] != proto.RespArray {
   726  		return fmt.Errorf("redis: expected '*', but got line %q", line)
   727  	}
   728  
   729  	return nil
   730  }
   731  
   732  //------------------------------------------------------------------------------
   733  
// Client is a Redis client representing a pool of zero or more underlying connections.
// It's safe for concurrent use by multiple goroutines.
//
// Client creates and frees connections automatically; it also maintains a free pool
// of idle connections. You can control the pool size with the PoolSize field of
// Options.
type Client struct {
	*baseClient
	cmdable
}
   743  
   744  // NewClient returns a client to the Redis Server specified by Options.
   745  func NewClient(opt *Options) *Client {
   746  	if opt == nil {
   747  		panic("redis: NewClient nil options")
   748  	}
   749  	opt.init()
   750  
   751  	c := Client{
   752  		baseClient: &baseClient{
   753  			opt: opt,
   754  		},
   755  	}
   756  	c.init()
   757  	c.connPool = newConnPool(opt, c.dialHook)
   758  
   759  	return &c
   760  }
   761  
   762  func (c *Client) init() {
   763  	c.cmdable = c.Process
   764  	c.initHooks(hooks{
   765  		dial:       c.baseClient.dial,
   766  		process:    c.baseClient.process,
   767  		pipeline:   c.baseClient.processPipeline,
   768  		txPipeline: c.baseClient.processTxPipeline,
   769  	})
   770  }
   771  
   772  func (c *Client) WithTimeout(timeout time.Duration) *Client {
   773  	clone := *c
   774  	clone.baseClient = c.baseClient.withTimeout(timeout)
   775  	clone.init()
   776  	return &clone
   777  }
   778  
// Conn returns a Conn bound to a single connection from the client's pool,
// obtained through a sticky pool wrapper. The Conn uses a copy of the client's
// hooks.
func (c *Client) Conn() *Conn {
	return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin)
}

// Process runs cmd through the client's process hook chain. It records the
// resulting error (nil on success) on cmd before returning it.
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
	err := c.processHook(ctx, cmd)
	cmd.SetErr(err)
	return err
}

// Options returns the Options that were used to create the client.
// The returned pointer is the client's own copy and should be treated as
// read-only.
func (c *Client) Options() *Options {
	return c.opt
}

// PoolStats is the exported form of the internal pool.Stats.
type PoolStats pool.Stats

// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
	stats := c.connPool.Stats()
	return (*PoolStats)(stats)
}
   801  
   802  func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
   803  	return c.Pipeline().Pipelined(ctx, fn)
   804  }
   805  
   806  func (c *Client) Pipeline() Pipeliner {
   807  	pipe := Pipeline{
   808  		exec: pipelineExecer(c.processPipelineHook),
   809  	}
   810  	pipe.init()
   811  	return &pipe
   812  }
   813  
   814  func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
   815  	return c.TxPipeline().Pipelined(ctx, fn)
   816  }
   817  
   818  // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
   819  func (c *Client) TxPipeline() Pipeliner {
   820  	pipe := Pipeline{
   821  		exec: func(ctx context.Context, cmds []Cmder) error {
   822  			cmds = wrapMultiExec(ctx, cmds)
   823  			return c.processTxPipelineHook(ctx, cmds)
   824  		},
   825  	}
   826  	pipe.init()
   827  	return &pipe
   828  }
   829  
   830  func (c *Client) pubSub() *PubSub {
   831  	pubsub := &PubSub{
   832  		opt: c.opt,
   833  
   834  		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
   835  			return c.newConn(ctx)
   836  		},
   837  		closeConn: c.connPool.CloseConn,
   838  	}
   839  	pubsub.init()
   840  	return pubsub
   841  }
   842  
   843  // Subscribe subscribes the client to the specified channels.
   844  // Channels can be omitted to create empty subscription.
   845  // Note that this method does not wait on a response from Redis, so the
   846  // subscription may not be active immediately. To force the connection to wait,
   847  // you may call the Receive() method on the returned *PubSub like so:
   848  //
   849  //	sub := client.Subscribe(queryResp)
   850  //	iface, err := sub.Receive()
   851  //	if err != nil {
   852  //	    // handle error
   853  //	}
   854  //
   855  //	// Should be *Subscription, but others are possible if other actions have been
   856  //	// taken on sub since it was created.
   857  //	switch iface.(type) {
   858  //	case *Subscription:
   859  //	    // subscribe succeeded
   860  //	case *Message:
   861  //	    // received first message
   862  //	case *Pong:
   863  //	    // pong received
   864  //	default:
   865  //	    // handle error
   866  //	}
   867  //
   868  //	ch := sub.Channel()
   869  func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
   870  	pubsub := c.pubSub()
   871  	if len(channels) > 0 {
   872  		_ = pubsub.Subscribe(ctx, channels...)
   873  	}
   874  	return pubsub
   875  }
   876  
   877  // PSubscribe subscribes the client to the given patterns.
   878  // Patterns can be omitted to create empty subscription.
   879  func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
   880  	pubsub := c.pubSub()
   881  	if len(channels) > 0 {
   882  		_ = pubsub.PSubscribe(ctx, channels...)
   883  	}
   884  	return pubsub
   885  }
   886  
   887  // SSubscribe Subscribes the client to the specified shard channels.
   888  // Channels can be omitted to create empty subscription.
   889  func (c *Client) SSubscribe(ctx context.Context, channels ...string) *PubSub {
   890  	pubsub := c.pubSub()
   891  	if len(channels) > 0 {
   892  		_ = pubsub.SSubscribe(ctx, channels...)
   893  	}
   894  	return pubsub
   895  }
   896  
   897  //------------------------------------------------------------------------------
   898  
// Conn represents a single Redis connection rather than a pool of connections.
// Prefer running commands from Client unless there is a specific need
// for a continuous single Redis connection.
//
// Unlike Client, Conn embeds baseClient by value, and in addition to cmdable
// it exposes statefulCmdable. Both are bound to Conn.Process by newConn.
type Conn struct {
	baseClient
	cmdable
	statefulCmdable
}
   907  
   908  // newConn is a helper func to create a new Conn instance.
   909  // the Conn instance is not thread-safe and should not be shared between goroutines.
   910  // the parentHooks will be cloned, no need to clone before passing it.
   911  func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn {
   912  	c := Conn{
   913  		baseClient: baseClient{
   914  			opt:      opt,
   915  			connPool: connPool,
   916  		},
   917  	}
   918  
   919  	if parentHooks != nil {
   920  		c.hooksMixin = parentHooks.clone()
   921  	}
   922  
   923  	c.cmdable = c.Process
   924  	c.statefulCmdable = c.Process
   925  	c.initHooks(hooks{
   926  		dial:       c.baseClient.dial,
   927  		process:    c.baseClient.process,
   928  		pipeline:   c.baseClient.processPipeline,
   929  		txPipeline: c.baseClient.processTxPipeline,
   930  	})
   931  
   932  	return &c
   933  }
   934  
   935  func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
   936  	err := c.processHook(ctx, cmd)
   937  	cmd.SetErr(err)
   938  	return err
   939  }
   940  
   941  func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
   942  	return c.Pipeline().Pipelined(ctx, fn)
   943  }
   944  
   945  func (c *Conn) Pipeline() Pipeliner {
   946  	pipe := Pipeline{
   947  		exec: c.processPipelineHook,
   948  	}
   949  	pipe.init()
   950  	return &pipe
   951  }
   952  
   953  func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
   954  	return c.TxPipeline().Pipelined(ctx, fn)
   955  }
   956  
   957  // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
   958  func (c *Conn) TxPipeline() Pipeliner {
   959  	pipe := Pipeline{
   960  		exec: func(ctx context.Context, cmds []Cmder) error {
   961  			cmds = wrapMultiExec(ctx, cmds)
   962  			return c.processTxPipelineHook(ctx, cmds)
   963  		},
   964  	}
   965  	pipe.init()
   966  	return &pipe
   967  }
   968  

View as plain text