...

Source file src/github.com/redis/go-redis/v9/internal_test.go

Documentation: github.com/redis/go-redis/v9

     1  package redis
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/redis/go-redis/v9/internal/pool"
    13  	"github.com/redis/go-redis/v9/internal/proto"
    14  
    15  	. "github.com/bsm/ginkgo/v2"
    16  	. "github.com/bsm/gomega"
    17  )
    18  
    19  var _ = Describe("newClusterState", func() {
    20  	var state *clusterState
    21  
    22  	createClusterState := func(slots []ClusterSlot) *clusterState {
    23  		opt := &ClusterOptions{}
    24  		opt.init()
    25  		nodes := newClusterNodes(opt)
    26  		state, err := newClusterState(nodes, slots, "10.10.10.10:1234")
    27  		Expect(err).NotTo(HaveOccurred())
    28  		return state
    29  	}
    30  
    31  	Describe("sorting", func() {
    32  		BeforeEach(func() {
    33  			state = createClusterState([]ClusterSlot{{
    34  				Start: 1000,
    35  				End:   1999,
    36  			}, {
    37  				Start: 0,
    38  				End:   999,
    39  			}, {
    40  				Start: 2000,
    41  				End:   2999,
    42  			}})
    43  		})
    44  
    45  		It("sorts slots", func() {
    46  			Expect(state.slots).To(Equal([]*clusterSlot{
    47  				{start: 0, end: 999, nodes: nil},
    48  				{start: 1000, end: 1999, nodes: nil},
    49  				{start: 2000, end: 2999, nodes: nil},
    50  			}))
    51  		})
    52  	})
    53  
    54  	Describe("loopback", func() {
    55  		BeforeEach(func() {
    56  			state = createClusterState([]ClusterSlot{{
    57  				Nodes: []ClusterNode{{Addr: "127.0.0.1:7001"}},
    58  			}, {
    59  				Nodes: []ClusterNode{{Addr: "127.0.0.1:7002"}},
    60  			}, {
    61  				Nodes: []ClusterNode{{Addr: "1.2.3.4:1234"}},
    62  			}, {
    63  				Nodes: []ClusterNode{{Addr: ":1234"}},
    64  			}})
    65  		})
    66  
    67  		It("replaces loopback hosts in addresses", func() {
    68  			slotAddr := func(slot *clusterSlot) string {
    69  				return slot.nodes[0].Client.Options().Addr
    70  			}
    71  
    72  			Expect(slotAddr(state.slots[0])).To(Equal("10.10.10.10:7001"))
    73  			Expect(slotAddr(state.slots[1])).To(Equal("10.10.10.10:7002"))
    74  			Expect(slotAddr(state.slots[2])).To(Equal("1.2.3.4:1234"))
    75  			Expect(slotAddr(state.slots[3])).To(Equal(":1234"))
    76  		})
    77  	})
    78  })
    79  
    80  type fixedHash string
    81  
    82  func (h fixedHash) Get(string) string {
    83  	return string(h)
    84  }
    85  
    86  func TestRingSetAddrsAndRebalanceRace(t *testing.T) {
    87  	const (
    88  		ringShard1Name = "ringShardOne"
    89  		ringShard2Name = "ringShardTwo"
    90  
    91  		ringShard1Port = "6390"
    92  		ringShard2Port = "6391"
    93  	)
    94  
    95  	ring := NewRing(&RingOptions{
    96  		Addrs: map[string]string{
    97  			ringShard1Name: ":" + ringShard1Port,
    98  		},
    99  		// Disable heartbeat
   100  		HeartbeatFrequency: 1 * time.Hour,
   101  		NewConsistentHash: func(shards []string) ConsistentHash {
   102  			switch len(shards) {
   103  			case 1:
   104  				return fixedHash(ringShard1Name)
   105  			case 2:
   106  				return fixedHash(ringShard2Name)
   107  			default:
   108  				t.Fatalf("Unexpected number of shards: %v", shards)
   109  				return nil
   110  			}
   111  		},
   112  	})
   113  	defer ring.Close()
   114  
   115  	// Continuously update addresses by adding and removing one address
   116  	updatesDone := make(chan struct{})
   117  	defer func() { close(updatesDone) }()
   118  	go func() {
   119  		for i := 0; ; i++ {
   120  			select {
   121  			case <-updatesDone:
   122  				return
   123  			default:
   124  				if i%2 == 0 {
   125  					ring.SetAddrs(map[string]string{
   126  						ringShard1Name: ":" + ringShard1Port,
   127  					})
   128  				} else {
   129  					ring.SetAddrs(map[string]string{
   130  						ringShard1Name: ":" + ringShard1Port,
   131  						ringShard2Name: ":" + ringShard2Port,
   132  					})
   133  				}
   134  			}
   135  		}
   136  	}()
   137  
   138  	timer := time.NewTimer(1 * time.Second)
   139  	for running := true; running; {
   140  		select {
   141  		case <-timer.C:
   142  			running = false
   143  		default:
   144  			shard, err := ring.sharding.GetByKey("whatever")
   145  			if err == nil && shard == nil {
   146  				t.Fatal("shard is nil")
   147  			}
   148  		}
   149  	}
   150  }
   151  
   152  func BenchmarkRingShardingRebalanceLocked(b *testing.B) {
   153  	opts := &RingOptions{
   154  		Addrs: make(map[string]string),
   155  		// Disable heartbeat
   156  		HeartbeatFrequency: 1 * time.Hour,
   157  	}
   158  	for i := 0; i < 100; i++ {
   159  		opts.Addrs[fmt.Sprintf("shard%d", i)] = fmt.Sprintf(":63%02d", i)
   160  	}
   161  
   162  	ring := NewRing(opts)
   163  	defer ring.Close()
   164  
   165  	b.ResetTimer()
   166  	for i := 0; i < b.N; i++ {
   167  		ring.sharding.rebalanceLocked()
   168  	}
   169  }
   170  
   171  type testCounter struct {
   172  	mu sync.Mutex
   173  	t  *testing.T
   174  	m  map[string]int
   175  }
   176  
   177  func newTestCounter(t *testing.T) *testCounter {
   178  	return &testCounter{t: t, m: make(map[string]int)}
   179  }
   180  
   181  func (ct *testCounter) increment(key string) {
   182  	ct.mu.Lock()
   183  	defer ct.mu.Unlock()
   184  	ct.m[key]++
   185  }
   186  
   187  func (ct *testCounter) expect(values map[string]int) {
   188  	ct.mu.Lock()
   189  	defer ct.mu.Unlock()
   190  	ct.t.Helper()
   191  	if !reflect.DeepEqual(values, ct.m) {
   192  		ct.t.Errorf("expected %v != actual %v", values, ct.m)
   193  	}
   194  }
   195  
   196  func TestRingShardsCleanup(t *testing.T) {
   197  	const (
   198  		ringShard1Name = "ringShardOne"
   199  		ringShard2Name = "ringShardTwo"
   200  
   201  		ringShard1Addr = "shard1.test"
   202  		ringShard2Addr = "shard2.test"
   203  	)
   204  
   205  	t.Run("closes unused shards", func(t *testing.T) {
   206  		closeCounter := newTestCounter(t)
   207  
   208  		ring := NewRing(&RingOptions{
   209  			Addrs: map[string]string{
   210  				ringShard1Name: ringShard1Addr,
   211  				ringShard2Name: ringShard2Addr,
   212  			},
   213  			NewClient: func(opt *Options) *Client {
   214  				c := NewClient(opt)
   215  				c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
   216  					closeCounter.increment(opt.Addr)
   217  					return nil
   218  				})
   219  				return c
   220  			},
   221  		})
   222  		closeCounter.expect(map[string]int{})
   223  
   224  		// no change due to the same addresses
   225  		ring.SetAddrs(map[string]string{
   226  			ringShard1Name: ringShard1Addr,
   227  			ringShard2Name: ringShard2Addr,
   228  		})
   229  		closeCounter.expect(map[string]int{})
   230  
   231  		ring.SetAddrs(map[string]string{
   232  			ringShard1Name: ringShard1Addr,
   233  		})
   234  		closeCounter.expect(map[string]int{ringShard2Addr: 1})
   235  
   236  		ring.SetAddrs(map[string]string{
   237  			ringShard2Name: ringShard2Addr,
   238  		})
   239  		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
   240  
   241  		ring.Close()
   242  		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 2})
   243  	})
   244  
   245  	t.Run("closes created shards if ring was closed", func(t *testing.T) {
   246  		createCounter := newTestCounter(t)
   247  		closeCounter := newTestCounter(t)
   248  
   249  		var (
   250  			ring        *Ring
   251  			shouldClose int32
   252  		)
   253  
   254  		ring = NewRing(&RingOptions{
   255  			Addrs: map[string]string{
   256  				ringShard1Name: ringShard1Addr,
   257  			},
   258  			NewClient: func(opt *Options) *Client {
   259  				if atomic.LoadInt32(&shouldClose) != 0 {
   260  					ring.Close()
   261  				}
   262  				createCounter.increment(opt.Addr)
   263  				c := NewClient(opt)
   264  				c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
   265  					closeCounter.increment(opt.Addr)
   266  					return nil
   267  				})
   268  				return c
   269  			},
   270  		})
   271  		createCounter.expect(map[string]int{ringShard1Addr: 1})
   272  		closeCounter.expect(map[string]int{})
   273  
   274  		atomic.StoreInt32(&shouldClose, 1)
   275  
   276  		ring.SetAddrs(map[string]string{
   277  			ringShard2Name: ringShard2Addr,
   278  		})
   279  		createCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
   280  		closeCounter.expect(map[string]int{ringShard1Addr: 1, ringShard2Addr: 1})
   281  	})
   282  }
   283  
   284  //------------------------------------------------------------------------------
   285  
   286  type timeoutErr struct {
   287  	error
   288  }
   289  
   290  func (e timeoutErr) Timeout() bool {
   291  	return true
   292  }
   293  
   294  func (e timeoutErr) Temporary() bool {
   295  	return true
   296  }
   297  
   298  func (e timeoutErr) Error() string {
   299  	return "i/o timeout"
   300  }
   301  
   302  var _ = Describe("withConn", func() {
   303  	var client *Client
   304  
   305  	BeforeEach(func() {
   306  		client = NewClient(&Options{
   307  			PoolSize: 1,
   308  		})
   309  	})
   310  
   311  	AfterEach(func() {
   312  		client.Close()
   313  	})
   314  
   315  	It("should replace the connection in the pool when there is no error", func() {
   316  		var conn *pool.Conn
   317  
   318  		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
   319  			conn = c
   320  			return nil
   321  		})
   322  
   323  		newConn, err := client.connPool.Get(ctx)
   324  		Expect(err).To(BeNil())
   325  		Expect(newConn).To(Equal(conn))
   326  	})
   327  
   328  	It("should replace the connection in the pool when there is an error not related to a bad connection", func() {
   329  		var conn *pool.Conn
   330  
   331  		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
   332  			conn = c
   333  			return proto.RedisError("LOADING")
   334  		})
   335  
   336  		newConn, err := client.connPool.Get(ctx)
   337  		Expect(err).To(BeNil())
   338  		Expect(newConn).To(Equal(conn))
   339  	})
   340  
   341  	It("should remove the connection from the pool when it times out", func() {
   342  		var conn *pool.Conn
   343  
   344  		client.withConn(ctx, func(ctx context.Context, c *pool.Conn) error {
   345  			conn = c
   346  			return timeoutErr{}
   347  		})
   348  
   349  		newConn, err := client.connPool.Get(ctx)
   350  		Expect(err).To(BeNil())
   351  		Expect(newConn).NotTo(Equal(conn))
   352  		Expect(client.connPool.Len()).To(Equal(1))
   353  	})
   354  })
   355  
   356  var _ = Describe("ClusterClient", func() {
   357  	var client *ClusterClient
   358  
   359  	BeforeEach(func() {
   360  		client = &ClusterClient{}
   361  	})
   362  
   363  	Describe("cmdSlot", func() {
   364  		It("select slot from args for GETKEYSINSLOT command", func() {
   365  			cmd := NewStringSliceCmd(ctx, "cluster", "getkeysinslot", 100, 200)
   366  
   367  			slot := client.cmdSlot(cmd, -1)
   368  			Expect(slot).To(Equal(100))
   369  		})
   370  
   371  		It("select slot from args for COUNTKEYSINSLOT command", func() {
   372  			cmd := NewStringSliceCmd(ctx, "cluster", "countkeysinslot", 100)
   373  
   374  			slot := client.cmdSlot(cmd, -1)
   375  			Expect(slot).To(Equal(100))
   376  		})
   377  
   378  		It("follows preferred random slot", func() {
   379  			cmd := NewStatusCmd(ctx, "ping")
   380  
   381  			slot := client.cmdSlot(cmd, 101)
   382  			Expect(slot).To(Equal(101))
   383  		})
   384  	})
   385  })
   386  
   387  var _ = Describe("isLoopback", func() {
   388  	DescribeTable("should correctly identify loopback addresses",
   389  		func(host string, expected bool) {
   390  			result := isLoopback(host)
   391  			Expect(result).To(Equal(expected))
   392  		},
   393  		// IP addresses
   394  		Entry("IPv4 loopback", "127.0.0.1", true),
   395  		Entry("IPv6 loopback", "::1", true),
   396  		Entry("IPv4 non-loopback", "192.168.1.1", false),
   397  		Entry("IPv6 non-loopback", "2001:db8::1", false),
   398  
   399  		// Well-known loopback hostnames
   400  		Entry("localhost lowercase", "localhost", true),
   401  		Entry("localhost uppercase", "LOCALHOST", true),
   402  		Entry("localhost mixed case", "LocalHost", true),
   403  
   404  		// Docker-specific loopbacks
   405  		Entry("host.docker.internal", "host.docker.internal", true),
   406  		Entry("HOST.DOCKER.INTERNAL", "HOST.DOCKER.INTERNAL", true),
   407  		Entry("custom.docker.internal", "custom.docker.internal", true),
   408  		Entry("app.docker.internal", "app.docker.internal", true),
   409  
   410  		// Non-loopback hostnames
   411  		Entry("redis hostname", "redis-cluster", false),
   412  		Entry("FQDN", "redis.example.com", false),
   413  		Entry("docker but not internal", "redis.docker.com", false),
   414  
   415  		// Edge cases
   416  		Entry("empty string", "", false),
   417  		Entry("invalid IP", "256.256.256.256", false),
   418  		Entry("partial docker internal", "docker.internal", false),
   419  	)
   420  })
   421  

View as plain text