...

Source file src/github.com/redis/go-redis/v9/redis_test.go

Documentation: github.com/redis/go-redis/v9

     1  package redis_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	. "github.com/bsm/ginkgo/v2"
    14  	. "github.com/bsm/gomega"
    15  
    16  	"github.com/redis/go-redis/v9"
    17  	"github.com/redis/go-redis/v9/auth"
    18  )
    19  
    20  type redisHookError struct{}
    21  
    22  var _ redis.Hook = redisHookError{}
    23  
    24  func (redisHookError) DialHook(hook redis.DialHook) redis.DialHook {
    25  	return hook
    26  }
    27  
    28  func (redisHookError) ProcessHook(hook redis.ProcessHook) redis.ProcessHook {
    29  	return func(ctx context.Context, cmd redis.Cmder) error {
    30  		return errors.New("hook error")
    31  	}
    32  }
    33  
    34  func (redisHookError) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.ProcessPipelineHook {
    35  	return hook
    36  }
    37  
    38  func TestHookError(t *testing.T) {
    39  	rdb := redis.NewClient(&redis.Options{
    40  		Addr: ":6379",
    41  	})
    42  	rdb.AddHook(redisHookError{})
    43  
    44  	err := rdb.Ping(ctx).Err()
    45  	if err == nil {
    46  		t.Fatalf("got nil, expected an error")
    47  	}
    48  
    49  	wanted := "hook error"
    50  	if err.Error() != wanted {
    51  		t.Fatalf(`got %q, wanted %q`, err, wanted)
    52  	}
    53  }
    54  
    55  //------------------------------------------------------------------------------
    56  
    57  var _ = Describe("Client", func() {
    58  	var client *redis.Client
    59  
    60  	BeforeEach(func() {
    61  		client = redis.NewClient(redisOptions())
    62  		Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
    63  	})
    64  
    65  	AfterEach(func() {
    66  		client.Close()
    67  	})
    68  
    69  	It("should Stringer", func() {
    70  		Expect(client.String()).To(Equal(fmt.Sprintf("Redis<:%s db:0>", redisPort)))
    71  	})
    72  
    73  	It("supports context", func() {
    74  		ctx, cancel := context.WithCancel(ctx)
    75  		cancel()
    76  
    77  		err := client.Ping(ctx).Err()
    78  		Expect(err).To(MatchError("context canceled"))
    79  	})
    80  
    81  	It("supports WithTimeout", Label("NonRedisEnterprise"), func() {
    82  		err := client.ClientPause(ctx, time.Second).Err()
    83  		Expect(err).NotTo(HaveOccurred())
    84  
    85  		err = client.WithTimeout(10 * time.Millisecond).Ping(ctx).Err()
    86  		Expect(err).To(HaveOccurred())
    87  
    88  		err = client.Ping(ctx).Err()
    89  		Expect(err).NotTo(HaveOccurred())
    90  	})
    91  
    92  	It("do", func() {
    93  		val, err := client.Do(ctx, "ping").Result()
    94  		Expect(err).NotTo(HaveOccurred())
    95  		Expect(val).To(Equal("PONG"))
    96  	})
    97  
    98  	It("should ping", func() {
    99  		val, err := client.Ping(ctx).Result()
   100  		Expect(err).NotTo(HaveOccurred())
   101  		Expect(val).To(Equal("PONG"))
   102  	})
   103  
   104  	It("should return pool stats", func() {
   105  		Expect(client.PoolStats()).To(BeAssignableToTypeOf(&redis.PoolStats{}))
   106  	})
   107  
   108  	It("should support custom dialers", func() {
   109  		custom := redis.NewClient(&redis.Options{
   110  			Network: "tcp",
   111  			Addr:    redisAddr,
   112  			Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
   113  				var d net.Dialer
   114  				return d.DialContext(ctx, network, addr)
   115  			},
   116  		})
   117  
   118  		val, err := custom.Ping(ctx).Result()
   119  		Expect(err).NotTo(HaveOccurred())
   120  		Expect(val).To(Equal("PONG"))
   121  		Expect(custom.Close()).NotTo(HaveOccurred())
   122  	})
   123  
   124  	It("should close", func() {
   125  		Expect(client.Close()).NotTo(HaveOccurred())
   126  		err := client.Ping(ctx).Err()
   127  		Expect(err).To(MatchError("redis: client is closed"))
   128  	})
   129  
   130  	It("should close pubsub without closing the client", func() {
   131  		pubsub := client.Subscribe(ctx)
   132  		Expect(pubsub.Close()).NotTo(HaveOccurred())
   133  
   134  		_, err := pubsub.Receive(ctx)
   135  		Expect(err).To(MatchError("redis: client is closed"))
   136  		Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
   137  	})
   138  
   139  	It("should close Tx without closing the client", func() {
   140  		err := client.Watch(ctx, func(tx *redis.Tx) error {
   141  			_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
   142  				pipe.Ping(ctx)
   143  				return nil
   144  			})
   145  			return err
   146  		})
   147  		Expect(err).NotTo(HaveOccurred())
   148  
   149  		Expect(client.Ping(ctx).Err()).NotTo(HaveOccurred())
   150  	})
   151  
   152  	It("should close pubsub when client is closed", func() {
   153  		pubsub := client.Subscribe(ctx)
   154  		Expect(client.Close()).NotTo(HaveOccurred())
   155  
   156  		_, err := pubsub.Receive(ctx)
   157  		Expect(err).To(MatchError("redis: client is closed"))
   158  
   159  		Expect(pubsub.Close()).NotTo(HaveOccurred())
   160  	})
   161  
   162  	It("should select DB", Label("NonRedisEnterprise"), func() {
   163  		db2 := redis.NewClient(&redis.Options{
   164  			Addr: redisAddr,
   165  			DB:   2,
   166  		})
   167  		Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
   168  		Expect(db2.Get(ctx, "db").Err()).To(Equal(redis.Nil))
   169  		Expect(db2.Set(ctx, "db", 2, 0).Err()).NotTo(HaveOccurred())
   170  
   171  		n, err := db2.Get(ctx, "db").Int64()
   172  		Expect(err).NotTo(HaveOccurred())
   173  		Expect(n).To(Equal(int64(2)))
   174  
   175  		Expect(client.Get(ctx, "db").Err()).To(Equal(redis.Nil))
   176  
   177  		Expect(db2.FlushDB(ctx).Err()).NotTo(HaveOccurred())
   178  		Expect(db2.Close()).NotTo(HaveOccurred())
   179  	})
   180  
   181  	It("should client setname", func() {
   182  		opt := redisOptions()
   183  		opt.ClientName = "hi"
   184  		db := redis.NewClient(opt)
   185  
   186  		defer func() {
   187  			Expect(db.Close()).NotTo(HaveOccurred())
   188  		}()
   189  
   190  		Expect(db.Ping(ctx).Err()).NotTo(HaveOccurred())
   191  		val, err := db.ClientList(ctx).Result()
   192  		Expect(err).NotTo(HaveOccurred())
   193  		Expect(val).Should(ContainSubstring("name=hi"))
   194  	})
   195  
   196  	It("should attempt to set client name in HELLO", func() {
   197  		opt := redisOptions()
   198  		opt.ClientName = "hi"
   199  		db := redis.NewClient(opt)
   200  
   201  		defer func() {
   202  			Expect(db.Close()).NotTo(HaveOccurred())
   203  		}()
   204  
   205  		// Client name should be already set on any successfully initialized connection
   206  		name, err := db.ClientGetName(ctx).Result()
   207  		Expect(err).NotTo(HaveOccurred())
   208  		Expect(name).Should(Equal("hi"))
   209  
   210  		// HELLO should be able to explicitly overwrite the client name
   211  		conn := db.Conn()
   212  		hello, err := conn.Hello(ctx, 3, "", "", "hi2").Result()
   213  		Expect(err).NotTo(HaveOccurred())
   214  		Expect(hello["proto"]).Should(Equal(int64(3)))
   215  		name, err = conn.ClientGetName(ctx).Result()
   216  		Expect(err).NotTo(HaveOccurred())
   217  		Expect(name).Should(Equal("hi2"))
   218  		err = conn.Close()
   219  		Expect(err).NotTo(HaveOccurred())
   220  	})
   221  
   222  	It("should client PROTO 2", func() {
   223  		opt := redisOptions()
   224  		opt.Protocol = 2
   225  		db := redis.NewClient(opt)
   226  
   227  		defer func() {
   228  			Expect(db.Close()).NotTo(HaveOccurred())
   229  		}()
   230  
   231  		val, err := db.Do(ctx, "HELLO").Result()
   232  		Expect(err).NotTo(HaveOccurred())
   233  		Expect(val).Should(ContainElements("proto", int64(2)))
   234  	})
   235  
   236  	It("should client PROTO 3", func() {
   237  		opt := redisOptions()
   238  		db := redis.NewClient(opt)
   239  
   240  		defer func() {
   241  			Expect(db.Close()).NotTo(HaveOccurred())
   242  		}()
   243  
   244  		val, err := db.Do(ctx, "HELLO").Result()
   245  		Expect(err).NotTo(HaveOccurred())
   246  		Expect(val).Should(HaveKeyWithValue("proto", int64(3)))
   247  	})
   248  
   249  	It("processes custom commands", func() {
   250  		cmd := redis.NewCmd(ctx, "PING")
   251  		_ = client.Process(ctx, cmd)
   252  
   253  		// Flush buffers.
   254  		Expect(client.Echo(ctx, "hello").Err()).NotTo(HaveOccurred())
   255  
   256  		Expect(cmd.Err()).NotTo(HaveOccurred())
   257  		Expect(cmd.Val()).To(Equal("PONG"))
   258  	})
   259  
   260  	It("should retry command on network error", func() {
   261  		Expect(client.Close()).NotTo(HaveOccurred())
   262  
   263  		client = redis.NewClient(&redis.Options{
   264  			Addr:       redisAddr,
   265  			MaxRetries: 1,
   266  		})
   267  
   268  		// Put bad connection in the pool.
   269  		cn, err := client.Pool().Get(ctx)
   270  		Expect(err).NotTo(HaveOccurred())
   271  
   272  		cn.SetNetConn(&badConn{})
   273  		client.Pool().Put(ctx, cn)
   274  
   275  		err = client.Ping(ctx).Err()
   276  		Expect(err).NotTo(HaveOccurred())
   277  	})
   278  
   279  	It("should retry with backoff", func() {
   280  		clientNoRetry := redis.NewClient(&redis.Options{
   281  			Addr:       ":1234",
   282  			MaxRetries: -1,
   283  		})
   284  		defer clientNoRetry.Close()
   285  
   286  		clientRetry := redis.NewClient(&redis.Options{
   287  			Addr:            ":1234",
   288  			MaxRetries:      5,
   289  			MaxRetryBackoff: 128 * time.Millisecond,
   290  		})
   291  		defer clientRetry.Close()
   292  
   293  		startNoRetry := time.Now()
   294  		err := clientNoRetry.Ping(ctx).Err()
   295  		Expect(err).To(HaveOccurred())
   296  		elapseNoRetry := time.Since(startNoRetry)
   297  
   298  		startRetry := time.Now()
   299  		err = clientRetry.Ping(ctx).Err()
   300  		Expect(err).To(HaveOccurred())
   301  		elapseRetry := time.Since(startRetry)
   302  
   303  		Expect(elapseRetry).To(BeNumerically(">", elapseNoRetry, 10*time.Millisecond))
   304  	})
   305  
   306  	It("should update conn.UsedAt on read/write", func() {
   307  		cn, err := client.Pool().Get(context.Background())
   308  		Expect(err).NotTo(HaveOccurred())
   309  		Expect(cn.UsedAt).NotTo(BeZero())
   310  
   311  		// set cn.SetUsedAt(time) or time.Sleep(>1*time.Second)
   312  		// simulate the last time Conn was used
   313  		// time.Sleep() is not the standard sleep time
   314  		// link: https://go-review.googlesource.com/c/go/+/232298
   315  		cn.SetUsedAt(time.Now().Add(-1 * time.Second))
   316  		createdAt := cn.UsedAt()
   317  
   318  		client.Pool().Put(ctx, cn)
   319  		Expect(cn.UsedAt().Equal(createdAt)).To(BeTrue())
   320  
   321  		err = client.Ping(ctx).Err()
   322  		Expect(err).NotTo(HaveOccurred())
   323  
   324  		cn, err = client.Pool().Get(context.Background())
   325  		Expect(err).NotTo(HaveOccurred())
   326  		Expect(cn).NotTo(BeNil())
   327  		Expect(cn.UsedAt().After(createdAt)).To(BeTrue())
   328  	})
   329  
   330  	It("should process command with special chars", func() {
   331  		set := client.Set(ctx, "key", "hello1\r\nhello2\r\n", 0)
   332  		Expect(set.Err()).NotTo(HaveOccurred())
   333  		Expect(set.Val()).To(Equal("OK"))
   334  
   335  		get := client.Get(ctx, "key")
   336  		Expect(get.Err()).NotTo(HaveOccurred())
   337  		Expect(get.Val()).To(Equal("hello1\r\nhello2\r\n"))
   338  	})
   339  
   340  	It("should handle big vals", func() {
   341  		bigVal := bytes.Repeat([]byte{'*'}, 2e6)
   342  
   343  		err := client.Set(ctx, "key", bigVal, 0).Err()
   344  		Expect(err).NotTo(HaveOccurred())
   345  
   346  		// Reconnect to get new connection.
   347  		Expect(client.Close()).NotTo(HaveOccurred())
   348  		client = redis.NewClient(redisOptions())
   349  
   350  		got, err := client.Get(ctx, "key").Bytes()
   351  		Expect(err).NotTo(HaveOccurred())
   352  		Expect(got).To(Equal(bigVal))
   353  	})
   354  
   355  	It("should set and scan time", func() {
   356  		tm := time.Now()
   357  		err := client.Set(ctx, "now", tm, 0).Err()
   358  		Expect(err).NotTo(HaveOccurred())
   359  
   360  		var tm2 time.Time
   361  		err = client.Get(ctx, "now").Scan(&tm2)
   362  		Expect(err).NotTo(HaveOccurred())
   363  
   364  		Expect(tm2).To(BeTemporally("==", tm))
   365  	})
   366  
   367  	It("should set and scan durations", func() {
   368  		duration := 10 * time.Minute
   369  		err := client.Set(ctx, "duration", duration, 0).Err()
   370  		Expect(err).NotTo(HaveOccurred())
   371  
   372  		var duration2 time.Duration
   373  		err = client.Get(ctx, "duration").Scan(&duration2)
   374  		Expect(err).NotTo(HaveOccurred())
   375  
   376  		Expect(duration2).To(Equal(duration))
   377  	})
   378  
   379  	It("should Conn", func() {
   380  		err := client.Conn().Get(ctx, "this-key-does-not-exist").Err()
   381  		Expect(err).To(Equal(redis.Nil))
   382  	})
   383  
   384  	It("should set and scan net.IP", func() {
   385  		ip := net.ParseIP("192.168.1.1")
   386  		err := client.Set(ctx, "ip", ip, 0).Err()
   387  		Expect(err).NotTo(HaveOccurred())
   388  
   389  		var ip2 net.IP
   390  		err = client.Get(ctx, "ip").Scan(&ip2)
   391  		Expect(err).NotTo(HaveOccurred())
   392  
   393  		Expect(ip2).To(Equal(ip))
   394  	})
   395  })
   396  
   397  var _ = Describe("Client timeout", func() {
   398  	var opt *redis.Options
   399  	var client *redis.Client
   400  
   401  	AfterEach(func() {
   402  		Expect(client.Close()).NotTo(HaveOccurred())
   403  	})
   404  
   405  	testTimeout := func() {
   406  		It("SETINFO timeouts", func() {
   407  			conn := client.Conn()
   408  			err := conn.Ping(ctx).Err()
   409  			Expect(err).To(HaveOccurred())
   410  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   411  		})
   412  
   413  		It("Ping timeouts", func() {
   414  			err := client.Ping(ctx).Err()
   415  			Expect(err).To(HaveOccurred())
   416  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   417  		})
   418  
   419  		It("Pipeline timeouts", func() {
   420  			_, err := client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
   421  				pipe.Ping(ctx)
   422  				return nil
   423  			})
   424  			Expect(err).To(HaveOccurred())
   425  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   426  		})
   427  
   428  		It("Subscribe timeouts", func() {
   429  			if opt.WriteTimeout == 0 {
   430  				return
   431  			}
   432  
   433  			pubsub := client.Subscribe(ctx)
   434  			defer pubsub.Close()
   435  
   436  			err := pubsub.Subscribe(ctx, "_")
   437  			Expect(err).To(HaveOccurred())
   438  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   439  		})
   440  
   441  		It("Tx timeouts", func() {
   442  			err := client.Watch(ctx, func(tx *redis.Tx) error {
   443  				return tx.Ping(ctx).Err()
   444  			})
   445  			Expect(err).To(HaveOccurred())
   446  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   447  		})
   448  
   449  		It("Tx Pipeline timeouts", func() {
   450  			err := client.Watch(ctx, func(tx *redis.Tx) error {
   451  				_, err := tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
   452  					pipe.Ping(ctx)
   453  					return nil
   454  				})
   455  				return err
   456  			})
   457  			Expect(err).To(HaveOccurred())
   458  			Expect(err.(net.Error).Timeout()).To(BeTrue())
   459  		})
   460  	}
   461  
   462  	Context("read timeout", func() {
   463  		BeforeEach(func() {
   464  			opt = redisOptions()
   465  			opt.ReadTimeout = time.Nanosecond
   466  			opt.WriteTimeout = -1
   467  			client = redis.NewClient(opt)
   468  		})
   469  
   470  		testTimeout()
   471  	})
   472  
   473  	Context("write timeout", func() {
   474  		BeforeEach(func() {
   475  			opt = redisOptions()
   476  			opt.ReadTimeout = -1
   477  			opt.WriteTimeout = time.Nanosecond
   478  			client = redis.NewClient(opt)
   479  		})
   480  
   481  		testTimeout()
   482  	})
   483  })
   484  
   485  var _ = Describe("Client OnConnect", func() {
   486  	var client *redis.Client
   487  
   488  	BeforeEach(func() {
   489  		opt := redisOptions()
   490  		opt.DB = 0
   491  		opt.OnConnect = func(ctx context.Context, cn *redis.Conn) error {
   492  			return cn.ClientSetName(ctx, "on_connect").Err()
   493  		}
   494  
   495  		client = redis.NewClient(opt)
   496  	})
   497  
   498  	AfterEach(func() {
   499  		Expect(client.Close()).NotTo(HaveOccurred())
   500  	})
   501  
   502  	It("calls OnConnect", func() {
   503  		name, err := client.ClientGetName(ctx).Result()
   504  		Expect(err).NotTo(HaveOccurred())
   505  		Expect(name).To(Equal("on_connect"))
   506  	})
   507  })
   508  
   509  var _ = Describe("Client context cancellation", func() {
   510  	var opt *redis.Options
   511  	var client *redis.Client
   512  
   513  	BeforeEach(func() {
   514  		opt = redisOptions()
   515  		opt.ReadTimeout = -1
   516  		opt.WriteTimeout = -1
   517  		client = redis.NewClient(opt)
   518  	})
   519  
   520  	AfterEach(func() {
   521  		Expect(client.Close()).NotTo(HaveOccurred())
   522  	})
   523  
   524  	It("Blocking operation cancellation", func() {
   525  		ctx, cancel := context.WithCancel(ctx)
   526  		cancel()
   527  
   528  		err := client.BLPop(ctx, 1*time.Second, "test").Err()
   529  		Expect(err).To(HaveOccurred())
   530  		Expect(err).To(BeIdenticalTo(context.Canceled))
   531  	})
   532  })
   533  
   534  var _ = Describe("Conn", func() {
   535  	var client *redis.Client
   536  
   537  	BeforeEach(func() {
   538  		client = redis.NewClient(redisOptions())
   539  		Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
   540  	})
   541  
   542  	AfterEach(func() {
   543  		err := client.Close()
   544  		Expect(err).NotTo(HaveOccurred())
   545  	})
   546  
   547  	It("TxPipeline", Label("NonRedisEnterprise"), func() {
   548  		tx := client.Conn().TxPipeline()
   549  		tx.SwapDB(ctx, 0, 2)
   550  		tx.SwapDB(ctx, 1, 0)
   551  		_, err := tx.Exec(ctx)
   552  		Expect(err).NotTo(HaveOccurred())
   553  	})
   554  })
   555  
   556  var _ = Describe("Hook", func() {
   557  	var client *redis.Client
   558  
   559  	BeforeEach(func() {
   560  		client = redis.NewClient(redisOptions())
   561  		Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
   562  	})
   563  
   564  	AfterEach(func() {
   565  		err := client.Close()
   566  		Expect(err).NotTo(HaveOccurred())
   567  	})
   568  
   569  	It("fifo", func() {
   570  		var res []string
   571  		client.AddHook(&hook{
   572  			processHook: func(hook redis.ProcessHook) redis.ProcessHook {
   573  				return func(ctx context.Context, cmd redis.Cmder) error {
   574  					res = append(res, "hook-1-process-start")
   575  					err := hook(ctx, cmd)
   576  					res = append(res, "hook-1-process-end")
   577  					return err
   578  				}
   579  			},
   580  		})
   581  		client.AddHook(&hook{
   582  			processHook: func(hook redis.ProcessHook) redis.ProcessHook {
   583  				return func(ctx context.Context, cmd redis.Cmder) error {
   584  					res = append(res, "hook-2-process-start")
   585  					err := hook(ctx, cmd)
   586  					res = append(res, "hook-2-process-end")
   587  					return err
   588  				}
   589  			},
   590  		})
   591  
   592  		err := client.Ping(ctx).Err()
   593  		Expect(err).NotTo(HaveOccurred())
   594  
   595  		Expect(res).To(Equal([]string{
   596  			"hook-1-process-start",
   597  			"hook-2-process-start",
   598  			"hook-2-process-end",
   599  			"hook-1-process-end",
   600  		}))
   601  	})
   602  
   603  	It("wrapped error in a hook", func() {
   604  		client.AddHook(&hook{
   605  			processHook: func(hook redis.ProcessHook) redis.ProcessHook {
   606  				return func(ctx context.Context, cmd redis.Cmder) error {
   607  					if err := hook(ctx, cmd); err != nil {
   608  						return fmt.Errorf("wrapped error: %w", err)
   609  					}
   610  					return nil
   611  				}
   612  			},
   613  		})
   614  		client.ScriptFlush(ctx)
   615  
   616  		script := redis.NewScript(`return 'Script and hook'`)
   617  
   618  		cmd := script.Run(ctx, client, nil)
   619  		Expect(cmd.Err()).NotTo(HaveOccurred())
   620  		Expect(cmd.Val()).To(Equal("Script and hook"))
   621  	})
   622  })
   623  
   624  var _ = Describe("Hook with MinIdleConns", func() {
   625  	var client *redis.Client
   626  
   627  	BeforeEach(func() {
   628  		options := redisOptions()
   629  		options.MinIdleConns = 1
   630  		client = redis.NewClient(options)
   631  		Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
   632  	})
   633  
   634  	AfterEach(func() {
   635  		err := client.Close()
   636  		Expect(err).NotTo(HaveOccurred())
   637  	})
   638  
   639  	It("fifo", func() {
   640  		var res []string
   641  		client.AddHook(&hook{
   642  			processHook: func(hook redis.ProcessHook) redis.ProcessHook {
   643  				return func(ctx context.Context, cmd redis.Cmder) error {
   644  					res = append(res, "hook-1-process-start")
   645  					err := hook(ctx, cmd)
   646  					res = append(res, "hook-1-process-end")
   647  					return err
   648  				}
   649  			},
   650  		})
   651  		client.AddHook(&hook{
   652  			processHook: func(hook redis.ProcessHook) redis.ProcessHook {
   653  				return func(ctx context.Context, cmd redis.Cmder) error {
   654  					res = append(res, "hook-2-process-start")
   655  					err := hook(ctx, cmd)
   656  					res = append(res, "hook-2-process-end")
   657  					return err
   658  				}
   659  			},
   660  		})
   661  
   662  		err := client.Ping(ctx).Err()
   663  		Expect(err).NotTo(HaveOccurred())
   664  
   665  		Expect(res).To(Equal([]string{
   666  			"hook-1-process-start",
   667  			"hook-2-process-start",
   668  			"hook-2-process-end",
   669  			"hook-1-process-end",
   670  		}))
   671  	})
   672  })
   673  
   674  var _ = Describe("Dialer connection timeouts", func() {
   675  	var client *redis.Client
   676  
   677  	const dialSimulatedDelay = 1 * time.Second
   678  
   679  	BeforeEach(func() {
   680  		options := redisOptions()
   681  		options.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
   682  			// Simulated slow dialer.
   683  			// Note that the following sleep is deliberately not context-aware.
   684  			time.Sleep(dialSimulatedDelay)
   685  			return net.Dial("tcp", options.Addr)
   686  		}
   687  		options.MinIdleConns = 1
   688  		client = redis.NewClient(options)
   689  	})
   690  
   691  	AfterEach(func() {
   692  		err := client.Close()
   693  		Expect(err).NotTo(HaveOccurred())
   694  	})
   695  
   696  	It("does not contend on connection dial for concurrent commands", func() {
   697  		var wg sync.WaitGroup
   698  
   699  		const concurrency = 10
   700  
   701  		durations := make(chan time.Duration, concurrency)
   702  		errs := make(chan error, concurrency)
   703  
   704  		start := time.Now()
   705  		wg.Add(concurrency)
   706  
   707  		for i := 0; i < concurrency; i++ {
   708  			go func() {
   709  				defer wg.Done()
   710  
   711  				start := time.Now()
   712  				err := client.Ping(ctx).Err()
   713  				durations <- time.Since(start)
   714  				errs <- err
   715  			}()
   716  		}
   717  
   718  		wg.Wait()
   719  		close(durations)
   720  		close(errs)
   721  
   722  		// All commands should eventually succeed, after acquiring a connection.
   723  		for err := range errs {
   724  			Expect(err).NotTo(HaveOccurred())
   725  		}
   726  
   727  		// Each individual command should complete within the simulated dial duration bound.
   728  		for duration := range durations {
   729  			Expect(duration).To(BeNumerically("<", 2*dialSimulatedDelay))
   730  		}
   731  
   732  		// Due to concurrent execution, the entire test suite should also complete within
   733  		// the same dial duration bound applied for individual commands.
   734  		Expect(time.Since(start)).To(BeNumerically("<", 2*dialSimulatedDelay))
   735  	})
   736  })
   737  
   738  var _ = Describe("Credentials Provider Priority", func() {
   739  	var client *redis.Client
   740  	var opt *redis.Options
   741  	var recorder *commandRecorder
   742  
   743  	BeforeEach(func() {
   744  		recorder = newCommandRecorder(10)
   745  	})
   746  
   747  	AfterEach(func() {
   748  		if client != nil {
   749  			Expect(client.Close()).NotTo(HaveOccurred())
   750  		}
   751  	})
   752  
   753  	It("should use streaming provider when available", func() {
   754  		streamingCreds := auth.NewBasicCredentials("streaming_user", "streaming_pass")
   755  		ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
   756  		providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
   757  
   758  		opt = &redis.Options{
   759  			Username: "field_user",
   760  			Password: "field_pass",
   761  			CredentialsProvider: func() (string, string) {
   762  				username, password := providerCreds.BasicAuth()
   763  				return username, password
   764  			},
   765  			CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
   766  				username, password := ctxCreds.BasicAuth()
   767  				return username, password, nil
   768  			},
   769  			StreamingCredentialsProvider: &mockStreamingProvider{
   770  				credentials: streamingCreds,
   771  				updates:     make(chan auth.Credentials, 1),
   772  			},
   773  		}
   774  
   775  		client = redis.NewClient(opt)
   776  		client.AddHook(recorder.Hook())
   777  		// wrongpass
   778  		Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
   779  		Expect(recorder.Contains("AUTH streaming_user")).To(BeTrue())
   780  	})
   781  
   782  	It("should use context provider when streaming provider is not available", func() {
   783  		ctxCreds := auth.NewBasicCredentials("ctx_user", "ctx_pass")
   784  		providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
   785  
   786  		opt = &redis.Options{
   787  			Username: "field_user",
   788  			Password: "field_pass",
   789  			CredentialsProvider: func() (string, string) {
   790  				username, password := providerCreds.BasicAuth()
   791  				return username, password
   792  			},
   793  			CredentialsProviderContext: func(ctx context.Context) (string, string, error) {
   794  				username, password := ctxCreds.BasicAuth()
   795  				return username, password, nil
   796  			},
   797  		}
   798  
   799  		client = redis.NewClient(opt)
   800  		client.AddHook(recorder.Hook())
   801  		// wrongpass
   802  		Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
   803  		Expect(recorder.Contains("AUTH ctx_user")).To(BeTrue())
   804  	})
   805  
   806  	It("should use regular provider when streaming and context providers are not available", func() {
   807  		providerCreds := auth.NewBasicCredentials("provider_user", "provider_pass")
   808  
   809  		opt = &redis.Options{
   810  			Username: "field_user",
   811  			Password: "field_pass",
   812  			CredentialsProvider: func() (string, string) {
   813  				username, password := providerCreds.BasicAuth()
   814  				return username, password
   815  			},
   816  		}
   817  
   818  		client = redis.NewClient(opt)
   819  		client.AddHook(recorder.Hook())
   820  		// wrongpass
   821  		Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
   822  		Expect(recorder.Contains("AUTH provider_user")).To(BeTrue())
   823  	})
   824  
   825  	It("should use username/password fields when no providers are set", func() {
   826  		opt = &redis.Options{
   827  			Username: "field_user",
   828  			Password: "field_pass",
   829  		}
   830  
   831  		client = redis.NewClient(opt)
   832  		client.AddHook(recorder.Hook())
   833  		// wrongpass
   834  		Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
   835  		Expect(recorder.Contains("AUTH field_user")).To(BeTrue())
   836  	})
   837  
   838  	It("should use empty credentials when nothing is set", func() {
   839  		opt = &redis.Options{}
   840  
   841  		client = redis.NewClient(opt)
   842  		client.AddHook(recorder.Hook())
   843  		// no pass, ok
   844  		Expect(client.Ping(context.Background()).Err()).NotTo(HaveOccurred())
   845  		Expect(recorder.Contains("AUTH")).To(BeFalse())
   846  	})
   847  
   848  	It("should handle credential updates from streaming provider", func() {
   849  		initialCreds := auth.NewBasicCredentials("initial_user", "initial_pass")
   850  		updatedCreds := auth.NewBasicCredentials("updated_user", "updated_pass")
   851  		updatesChan := make(chan auth.Credentials, 1)
   852  
   853  		opt = &redis.Options{
   854  			StreamingCredentialsProvider: &mockStreamingProvider{
   855  				credentials: initialCreds,
   856  				updates:     updatesChan,
   857  			},
   858  		}
   859  
   860  		client = redis.NewClient(opt)
   861  		client.AddHook(recorder.Hook())
   862  		// wrongpass
   863  		Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
   864  		Expect(recorder.Contains("AUTH initial_user")).To(BeTrue())
   865  
   866  		// Update credentials
   867  		opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds
   868  		// wrongpass
   869  		Expect(client.Ping(context.Background()).Err()).To(HaveOccurred())
   870  		Expect(recorder.Contains("AUTH updated_user")).To(BeTrue())
   871  		close(updatesChan)
   872  	})
   873  })
   874  
   875  type mockStreamingProvider struct {
   876  	credentials auth.Credentials
   877  	err         error
   878  	updates     chan auth.Credentials
   879  }
   880  
   881  func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
   882  	if m.err != nil {
   883  		return nil, nil, m.err
   884  	}
   885  
   886  	// Start goroutine to handle updates
   887  	go func() {
   888  		for creds := range m.updates {
   889  			m.credentials = creds
   890  			listener.OnNext(creds)
   891  		}
   892  	}()
   893  
   894  	return m.credentials, func() (err error) {
   895  		defer func() {
   896  			if r := recover(); r != nil {
   897  				// this is just a mock:
   898  				// allow multiple closes from multiple listeners
   899  			}
   900  		}()
   901  		return
   902  	}, nil
   903  }
   904  
   905  var _ = Describe("Client creation", func() {
   906  	Context("simple client with nil options", func() {
   907  		It("panics", func() {
   908  			Expect(func() {
   909  				redis.NewClient(nil)
   910  			}).To(Panic())
   911  		})
   912  	})
   913  	Context("cluster client with nil options", func() {
   914  		It("panics", func() {
   915  			Expect(func() {
   916  				redis.NewClusterClient(nil)
   917  			}).To(Panic())
   918  		})
   919  	})
   920  	Context("ring client with nil options", func() {
   921  		It("panics", func() {
   922  			Expect(func() {
   923  				redis.NewRing(nil)
   924  			}).To(Panic())
   925  		})
   926  	})
   927  	Context("universal client with nil options", func() {
   928  		It("panics", func() {
   929  			Expect(func() {
   930  				redis.NewUniversalClient(nil)
   931  			}).To(Panic())
   932  		})
   933  	})
   934  	Context("failover client with nil options", func() {
   935  		It("panics", func() {
   936  			Expect(func() {
   937  				redis.NewFailoverClient(nil)
   938  			}).To(Panic())
   939  		})
   940  	})
   941  	Context("failover cluster client with nil options", func() {
   942  		It("panics", func() {
   943  			Expect(func() {
   944  				redis.NewFailoverClusterClient(nil)
   945  			}).To(Panic())
   946  		})
   947  	})
   948  	Context("sentinel client with nil options", func() {
   949  		It("panics", func() {
   950  			Expect(func() {
   951  				redis.NewSentinelClient(nil)
   952  			}).To(Panic())
   953  		})
   954  	})
   955  })
   956  

View as plain text