...

Source file src/github.com/redis/go-redis/v9/vectorset_commands_integration_test.go

Documentation: github.com/redis/go-redis/v9

     1  package redis_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"time"
     8  
     9  	. "github.com/bsm/ginkgo/v2"
    10  	. "github.com/bsm/gomega"
    11  	"github.com/redis/go-redis/v9"
    12  	"github.com/redis/go-redis/v9/internal/proto"
    13  )
    14  
    15  func expectNil(err error) {
    16  	Expect(err).NotTo(HaveOccurred())
    17  }
    18  
    19  func expectTrue(t bool) {
    20  	expectEqual(t, true)
    21  }
    22  
    23  func expectEqual[T any, U any](a T, b U) {
    24  	Expect(a).To(BeEquivalentTo(b))
    25  }
    26  
    27  func generateRandomVector(dim int) redis.VectorValues {
    28  	rand.Seed(time.Now().UnixNano())
    29  	v := make([]float64, dim)
    30  	for i := range v {
    31  		v[i] = float64(rand.Intn(1000)) + rand.Float64()
    32  	}
    33  	return redis.VectorValues{Val: v}
    34  }
    35  
    36  var _ = Describe("Redis VectorSet commands", Label("vectorset"), func() {
    37  	ctx := context.TODO()
    38  
    39  	setupRedisClient := func(protocolVersion int) *redis.Client {
    40  		return redis.NewClient(&redis.Options{
    41  			Addr:          "localhost:6379",
    42  			DB:            0,
    43  			Protocol:      protocolVersion,
    44  			UnstableResp3: true,
    45  		})
    46  	}
    47  
    48  	protocols := []int{2, 3}
    49  	for _, protocol := range protocols {
    50  		protocol := protocol
    51  
    52  		Context(fmt.Sprintf("with protocol version %d", protocol), func() {
    53  			var client *redis.Client
    54  
    55  			BeforeEach(func() {
    56  				client = setupRedisClient(protocol)
    57  				Expect(client.FlushAll(ctx).Err()).NotTo(HaveOccurred())
    58  			})
    59  
    60  			AfterEach(func() {
    61  				if client != nil {
    62  					client.FlushDB(ctx)
    63  					client.Close()
    64  				}
    65  			})
    66  
    67  			It("basic", func() {
    68  				SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet")
    69  				vecName := "basic"
    70  				val := &redis.VectorValues{
    71  					Val: []float64{1.5, 2.4, 3.3, 4.2},
    72  				}
    73  				ok, err := client.VAdd(ctx, vecName, "k1", val).Result()
    74  				expectNil(err)
    75  				expectTrue(ok)
    76  
    77  				fp32 := "\x8f\xc2\xf9\x3e\xcb\xbe\xe9\xbe\xb0\x1e\xca\x3f\x5e\x06\x9e\x3f"
    78  				val2 := &redis.VectorFP32{
    79  					Val: []byte(fp32),
    80  				}
    81  				ok, err = client.VAdd(ctx, vecName, "k2", val2).Result()
    82  				expectNil(err)
    83  				expectTrue(ok)
    84  
    85  				dim, err := client.VDim(ctx, vecName).Result()
    86  				expectNil(err)
    87  				expectEqual(dim, 4)
    88  
    89  				count, err := client.VCard(ctx, vecName).Result()
    90  				expectNil(err)
    91  				expectEqual(count, 2)
    92  
    93  				ok, err = client.VRem(ctx, vecName, "k1").Result()
    94  				expectNil(err)
    95  				expectTrue(ok)
    96  
    97  				count, err = client.VCard(ctx, vecName).Result()
    98  				expectNil(err)
    99  				expectEqual(count, 1)
   100  			})
   101  
   102  			It("basic similarity", func() {
   103  				SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet")
   104  				vecName := "basic_similarity"
   105  
   106  				ok, err := client.VAdd(ctx, vecName, "k1", &redis.VectorValues{
   107  					Val: []float64{1, 0, 0, 0},
   108  				}).Result()
   109  				expectNil(err)
   110  				expectTrue(ok)
   111  				ok, err = client.VAdd(ctx, vecName, "k2", &redis.VectorValues{
   112  					Val: []float64{0.99, 0.01, 0, 0},
   113  				}).Result()
   114  				expectNil(err)
   115  				expectTrue(ok)
   116  				ok, err = client.VAdd(ctx, vecName, "k3", &redis.VectorValues{
   117  					Val: []float64{0.1, 1, -1, 0.5},
   118  				}).Result()
   119  				expectNil(err)
   120  				expectTrue(ok)
   121  
   122  				sim, err := client.VSimWithScores(ctx, vecName, &redis.VectorValues{
   123  					Val: []float64{1, 0, 0, 0},
   124  				}).Result()
   125  				expectNil(err)
   126  				expectEqual(len(sim), 3)
   127  				simMap := make(map[string]float64)
   128  				for _, vi := range sim {
   129  					simMap[vi.Name] = vi.Score
   130  				}
   131  				expectTrue(simMap["k1"] > 0.99)
   132  				expectTrue(simMap["k2"] > 0.99)
   133  				expectTrue(simMap["k3"] < 0.8)
   134  			})
   135  
   136  			It("dimension operation", func() {
   137  				SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet")
   138  				vecName := "dimension_op"
   139  				originalDim := 100
   140  				reducedDim := 50
   141  
   142  				v1 := generateRandomVector(originalDim)
   143  				ok, err := client.VAddWithArgs(ctx, vecName, "k1", &v1, &redis.VAddArgs{
   144  					Reduce: int64(reducedDim),
   145  				}).Result()
   146  				expectNil(err)
   147  				expectTrue(ok)
   148  
   149  				info, err := client.VInfo(ctx, vecName).Result()
   150  				expectNil(err)
   151  				dim := info["vector-dim"].(int64)
   152  				oriDim := info["projection-input-dim"].(int64)
   153  				expectEqual(dim, reducedDim)
   154  				expectEqual(oriDim, originalDim)
   155  
   156  				wrongDim := 80
   157  				wrongV := generateRandomVector(wrongDim)
   158  				_, err = client.VAddWithArgs(ctx, vecName, "kw", &wrongV, &redis.VAddArgs{
   159  					Reduce: int64(reducedDim),
   160  				}).Result()
   161  				expectTrue(err != nil)
   162  
   163  				v2 := generateRandomVector(originalDim)
   164  				ok, err = client.VAddWithArgs(ctx, vecName, "k2", &v2, &redis.VAddArgs{
   165  					Reduce: int64(reducedDim),
   166  				}).Result()
   167  				expectNil(err)
   168  				expectTrue(ok)
   169  			})
   170  
   171  			It("remove", func() {
   172  				SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet")
   173  				vecName := "remove"
   174  				v1 := generateRandomVector(5)
   175  				ok, err := client.VAdd(ctx, vecName, "k1", &v1).Result()
   176  				expectNil(err)
   177  				expectTrue(ok)
   178  
   179  				exist, err := client.Exists(ctx, vecName).Result()
   180  				expectNil(err)
   181  				expectEqual(exist, 1)
   182  
   183  				ok, err = client.VRem(ctx, vecName, "k1").Result()
   184  				expectNil(err)
   185  				expectTrue(ok)
   186  
   187  				exist, err = client.Exists(ctx, vecName).Result()
   188  				expectNil(err)
   189  				expectEqual(exist, 0)
   190  			})
   191  
   192  			It("all operations", func() {
   193  				SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet")
   194  				vecName := "commands"
   195  				vals := []struct {
   196  					name string
   197  					v    redis.VectorValues
   198  					attr string
   199  				}{
   200  					{
   201  						name: "k0",
   202  						v:    redis.VectorValues{Val: []float64{1, 0, 0, 0}},
   203  						attr: `{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}`,
   204  					},
   205  					{
   206  						name: "k1",
   207  						v:    redis.VectorValues{Val: []float64{0, 1, 0, 0}},
   208  						attr: `{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}`,
   209  					},
   210  					{
   211  						name: "k2",
   212  						v:    redis.VectorValues{Val: []float64{0, 0, 1, 0}},
   213  						attr: `{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}`,
   214  					},
   215  					{
   216  						name: "k3",
   217  						v:    redis.VectorValues{Val: []float64{0, 0, 0, 1}},
   218  					},
   219  					{
   220  						name: "k4",
   221  						v:    redis.VectorValues{Val: []float64{0.5, 0.5, 0, 0}},
   222  						attr: `invalid json`,
   223  					},
   224  				}
   225  
   226  				// If the key doesn't exist, return null error
   227  				_, err := client.VRandMember(ctx, vecName).Result()
   228  				expectEqual(err.Error(), proto.Nil.Error())
   229  
   230  				// If the key doesn't exist, return an empty array
   231  				res, err := client.VRandMemberCount(ctx, vecName, 3).Result()
   232  				expectNil(err)
   233  				expectEqual(len(res), 0)
   234  
   235  				for _, v := range vals {
   236  					ok, err := client.VAdd(ctx, vecName, v.name, &v.v).Result()
   237  					expectNil(err)
   238  					expectTrue(ok)
   239  					if len(v.attr) > 0 {
   240  						ok, err = client.VSetAttr(ctx, vecName, v.name, v.attr).Result()
   241  						expectNil(err)
   242  						expectTrue(ok)
   243  					}
   244  				}
   245  
   246  				// VGetAttr
   247  				attr, err := client.VGetAttr(ctx, vecName, vals[1].name).Result()
   248  				expectNil(err)
   249  				expectEqual(attr, vals[1].attr)
   250  
   251  				// VRandMember
   252  				_, err = client.VRandMember(ctx, vecName).Result()
   253  				expectNil(err)
   254  
   255  				res, err = client.VRandMemberCount(ctx, vecName, 3).Result()
   256  				expectNil(err)
   257  				expectEqual(len(res), 3)
   258  
   259  				res, err = client.VRandMemberCount(ctx, vecName, 10).Result()
   260  				expectNil(err)
   261  				expectEqual(len(res), len(vals))
   262  
   263  				// test equality
   264  				sim, err := client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   265  					Filter: `.age == 25`,
   266  				}).Result()
   267  				expectNil(err)
   268  				expectEqual(len(sim), 1)
   269  				expectEqual(sim[0], vals[0].name)
   270  
   271  				// test greater than
   272  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   273  					Filter: `.age > 25`,
   274  				}).Result()
   275  				expectNil(err)
   276  				expectEqual(len(sim), 2)
   277  
   278  				// test less than or equal
   279  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   280  					Filter: `.age <= 30`,
   281  				}).Result()
   282  				expectNil(err)
   283  				expectEqual(len(sim), 2)
   284  
   285  				// test string equality
   286  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   287  					Filter: `.name == "Alice"`,
   288  				}).Result()
   289  				expectNil(err)
   290  				expectEqual(len(sim), 1)
   291  				expectEqual(sim[0], vals[0].name)
   292  
   293  				// test string inequality
   294  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   295  					Filter: `.name != "Alice"`,
   296  				}).Result()
   297  				expectNil(err)
   298  				expectEqual(len(sim), 2)
   299  
   300  				// test bool
   301  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   302  					Filter: `.active`,
   303  				}).Result()
   304  				expectNil(err)
   305  				expectEqual(len(sim), 1)
   306  				expectEqual(sim[0], vals[0].name)
   307  
   308  				// test logical add
   309  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   310  					Filter: `.age > 20 and .age < 30`,
   311  				}).Result()
   312  				expectNil(err)
   313  				expectEqual(len(sim), 1)
   314  				expectEqual(sim[0], vals[0].name)
   315  
   316  				// test logical or
   317  				sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{
   318  					Filter: `.age < 30 or .age > 35`,
   319  				}).Result()
   320  				expectNil(err)
   321  				expectEqual(len(sim), 1)
   322  				expectEqual(sim[0], vals[0].name)
   323  			})
   324  		})
   325  	}
   326  })
   327  

View as plain text