...

Source file src/github.com/redis/go-redis/v9/pubsub.go

Documentation: github.com/redis/go-redis/v9

     1  package redis
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/redis/go-redis/v9/internal"
    11  	"github.com/redis/go-redis/v9/internal/pool"
    12  	"github.com/redis/go-redis/v9/internal/proto"
    13  )
    14  
    15  // PubSub implements Pub/Sub commands as described in
    16  // http://redis.io/topics/pubsub. Message receiving is NOT safe
    17  // for concurrent use by multiple goroutines.
    18  //
    19  // PubSub automatically reconnects to Redis Server and resubscribes
    20  // to the channels in case of network errors.
    21  type PubSub struct {
    22  	opt *Options
    23  
    24  	newConn   func(ctx context.Context, channels []string) (*pool.Conn, error)
    25  	closeConn func(*pool.Conn) error
    26  
    27  	mu        sync.Mutex
    28  	cn        *pool.Conn
    29  	channels  map[string]struct{}
    30  	patterns  map[string]struct{}
    31  	schannels map[string]struct{}
    32  
    33  	closed bool
    34  	exit   chan struct{}
    35  
    36  	cmd *Cmd
    37  
    38  	chOnce sync.Once
    39  	msgCh  *channel
    40  	allCh  *channel
    41  }
    42  
    43  func (c *PubSub) init() {
    44  	c.exit = make(chan struct{})
    45  }
    46  
    47  func (c *PubSub) String() string {
    48  	c.mu.Lock()
    49  	defer c.mu.Unlock()
    50  
    51  	channels := mapKeys(c.channels)
    52  	channels = append(channels, mapKeys(c.patterns)...)
    53  	channels = append(channels, mapKeys(c.schannels)...)
    54  	return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
    55  }
    56  
    57  func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
    58  	c.mu.Lock()
    59  	cn, err := c.conn(ctx, nil)
    60  	c.mu.Unlock()
    61  	return cn, err
    62  }
    63  
    64  func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
    65  	if c.closed {
    66  		return nil, pool.ErrClosed
    67  	}
    68  	if c.cn != nil {
    69  		return c.cn, nil
    70  	}
    71  
    72  	channels := mapKeys(c.channels)
    73  	channels = append(channels, newChannels...)
    74  
    75  	cn, err := c.newConn(ctx, channels)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	if err := c.resubscribe(ctx, cn); err != nil {
    81  		_ = c.closeConn(cn)
    82  		return nil, err
    83  	}
    84  
    85  	c.cn = cn
    86  	return cn, nil
    87  }
    88  
    89  func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
    90  	return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
    91  		return writeCmd(wr, cmd)
    92  	})
    93  }
    94  
    95  func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
    96  	var firstErr error
    97  
    98  	if len(c.channels) > 0 {
    99  		firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
   100  	}
   101  
   102  	if len(c.patterns) > 0 {
   103  		err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
   104  		if err != nil && firstErr == nil {
   105  			firstErr = err
   106  		}
   107  	}
   108  
   109  	if len(c.schannels) > 0 {
   110  		err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels))
   111  		if err != nil && firstErr == nil {
   112  			firstErr = err
   113  		}
   114  	}
   115  
   116  	return firstErr
   117  }
   118  
   119  func mapKeys(m map[string]struct{}) []string {
   120  	s := make([]string, len(m))
   121  	i := 0
   122  	for k := range m {
   123  		s[i] = k
   124  		i++
   125  	}
   126  	return s
   127  }
   128  
   129  func (c *PubSub) _subscribe(
   130  	ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
   131  ) error {
   132  	args := make([]interface{}, 0, 1+len(channels))
   133  	args = append(args, redisCmd)
   134  	for _, channel := range channels {
   135  		args = append(args, channel)
   136  	}
   137  	cmd := NewSliceCmd(ctx, args...)
   138  	return c.writeCmd(ctx, cn, cmd)
   139  }
   140  
   141  func (c *PubSub) releaseConnWithLock(
   142  	ctx context.Context,
   143  	cn *pool.Conn,
   144  	err error,
   145  	allowTimeout bool,
   146  ) {
   147  	c.mu.Lock()
   148  	c.releaseConn(ctx, cn, err, allowTimeout)
   149  	c.mu.Unlock()
   150  }
   151  
   152  func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
   153  	if c.cn != cn {
   154  		return
   155  	}
   156  	if isBadConn(err, allowTimeout, c.opt.Addr) {
   157  		c.reconnect(ctx, err)
   158  	}
   159  }
   160  
   161  func (c *PubSub) reconnect(ctx context.Context, reason error) {
   162  	_ = c.closeTheCn(reason)
   163  	_, _ = c.conn(ctx, nil)
   164  }
   165  
   166  func (c *PubSub) closeTheCn(reason error) error {
   167  	if c.cn == nil {
   168  		return nil
   169  	}
   170  	if !c.closed {
   171  		internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
   172  	}
   173  	err := c.closeConn(c.cn)
   174  	c.cn = nil
   175  	return err
   176  }
   177  
   178  func (c *PubSub) Close() error {
   179  	c.mu.Lock()
   180  	defer c.mu.Unlock()
   181  
   182  	if c.closed {
   183  		return pool.ErrClosed
   184  	}
   185  	c.closed = true
   186  	close(c.exit)
   187  
   188  	return c.closeTheCn(pool.ErrClosed)
   189  }
   190  
   191  // Subscribe the client to the specified channels. It returns
   192  // empty subscription if there are no channels.
   193  func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
   194  	c.mu.Lock()
   195  	defer c.mu.Unlock()
   196  
   197  	err := c.subscribe(ctx, "subscribe", channels...)
   198  	if c.channels == nil {
   199  		c.channels = make(map[string]struct{})
   200  	}
   201  	for _, s := range channels {
   202  		c.channels[s] = struct{}{}
   203  	}
   204  	return err
   205  }
   206  
   207  // PSubscribe the client to the given patterns. It returns
   208  // empty subscription if there are no patterns.
   209  func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
   210  	c.mu.Lock()
   211  	defer c.mu.Unlock()
   212  
   213  	err := c.subscribe(ctx, "psubscribe", patterns...)
   214  	if c.patterns == nil {
   215  		c.patterns = make(map[string]struct{})
   216  	}
   217  	for _, s := range patterns {
   218  		c.patterns[s] = struct{}{}
   219  	}
   220  	return err
   221  }
   222  
   223  // SSubscribe Subscribes the client to the specified shard channels.
   224  func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error {
   225  	c.mu.Lock()
   226  	defer c.mu.Unlock()
   227  
   228  	err := c.subscribe(ctx, "ssubscribe", channels...)
   229  	if c.schannels == nil {
   230  		c.schannels = make(map[string]struct{})
   231  	}
   232  	for _, s := range channels {
   233  		c.schannels[s] = struct{}{}
   234  	}
   235  	return err
   236  }
   237  
   238  // Unsubscribe the client from the given channels, or from all of
   239  // them if none is given.
   240  func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
   241  	c.mu.Lock()
   242  	defer c.mu.Unlock()
   243  
   244  	if len(channels) > 0 {
   245  		for _, channel := range channels {
   246  			delete(c.channels, channel)
   247  		}
   248  	} else {
   249  		// Unsubscribe from all channels.
   250  		for channel := range c.channels {
   251  			delete(c.channels, channel)
   252  		}
   253  	}
   254  
   255  	err := c.subscribe(ctx, "unsubscribe", channels...)
   256  	return err
   257  }
   258  
   259  // PUnsubscribe the client from the given patterns, or from all of
   260  // them if none is given.
   261  func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
   262  	c.mu.Lock()
   263  	defer c.mu.Unlock()
   264  
   265  	if len(patterns) > 0 {
   266  		for _, pattern := range patterns {
   267  			delete(c.patterns, pattern)
   268  		}
   269  	} else {
   270  		// Unsubscribe from all patterns.
   271  		for pattern := range c.patterns {
   272  			delete(c.patterns, pattern)
   273  		}
   274  	}
   275  
   276  	err := c.subscribe(ctx, "punsubscribe", patterns...)
   277  	return err
   278  }
   279  
   280  // SUnsubscribe unsubscribes the client from the given shard channels,
   281  // or from all of them if none is given.
   282  func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error {
   283  	c.mu.Lock()
   284  	defer c.mu.Unlock()
   285  
   286  	if len(channels) > 0 {
   287  		for _, channel := range channels {
   288  			delete(c.schannels, channel)
   289  		}
   290  	} else {
   291  		// Unsubscribe from all channels.
   292  		for channel := range c.schannels {
   293  			delete(c.schannels, channel)
   294  		}
   295  	}
   296  
   297  	err := c.subscribe(ctx, "sunsubscribe", channels...)
   298  	return err
   299  }
   300  
   301  func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
   302  	cn, err := c.conn(ctx, channels)
   303  	if err != nil {
   304  		return err
   305  	}
   306  
   307  	err = c._subscribe(ctx, cn, redisCmd, channels)
   308  	c.releaseConn(ctx, cn, err, false)
   309  	return err
   310  }
   311  
   312  func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
   313  	args := []interface{}{"ping"}
   314  	if len(payload) == 1 {
   315  		args = append(args, payload[0])
   316  	}
   317  	cmd := NewCmd(ctx, args...)
   318  
   319  	c.mu.Lock()
   320  	defer c.mu.Unlock()
   321  
   322  	cn, err := c.conn(ctx, nil)
   323  	if err != nil {
   324  		return err
   325  	}
   326  
   327  	err = c.writeCmd(ctx, cn, cmd)
   328  	c.releaseConn(ctx, cn, err, false)
   329  	return err
   330  }
   331  
   332  // Subscription received after a successful subscription to channel.
   333  type Subscription struct {
   334  	// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
   335  	Kind string
   336  	// Channel name we have subscribed to.
   337  	Channel string
   338  	// Number of channels we are currently subscribed to.
   339  	Count int
   340  }
   341  
   342  func (m *Subscription) String() string {
   343  	return fmt.Sprintf("%s: %s", m.Kind, m.Channel)
   344  }
   345  
   346  // Message received as result of a PUBLISH command issued by another client.
   347  type Message struct {
   348  	Channel      string
   349  	Pattern      string
   350  	Payload      string
   351  	PayloadSlice []string
   352  }
   353  
   354  func (m *Message) String() string {
   355  	return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
   356  }
   357  
   358  // Pong received as result of a PING command issued by another client.
   359  type Pong struct {
   360  	Payload string
   361  }
   362  
   363  func (p *Pong) String() string {
   364  	if p.Payload != "" {
   365  		return fmt.Sprintf("Pong<%s>", p.Payload)
   366  	}
   367  	return "Pong"
   368  }
   369  
   370  func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
   371  	switch reply := reply.(type) {
   372  	case string:
   373  		return &Pong{
   374  			Payload: reply,
   375  		}, nil
   376  	case []interface{}:
   377  		switch kind := reply[0].(string); kind {
   378  		case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe":
   379  			// Can be nil in case of "unsubscribe".
   380  			channel, _ := reply[1].(string)
   381  			return &Subscription{
   382  				Kind:    kind,
   383  				Channel: channel,
   384  				Count:   int(reply[2].(int64)),
   385  			}, nil
   386  		case "message", "smessage":
   387  			switch payload := reply[2].(type) {
   388  			case string:
   389  				return &Message{
   390  					Channel: reply[1].(string),
   391  					Payload: payload,
   392  				}, nil
   393  			case []interface{}:
   394  				ss := make([]string, len(payload))
   395  				for i, s := range payload {
   396  					ss[i] = s.(string)
   397  				}
   398  				return &Message{
   399  					Channel:      reply[1].(string),
   400  					PayloadSlice: ss,
   401  				}, nil
   402  			default:
   403  				return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
   404  			}
   405  		case "pmessage":
   406  			return &Message{
   407  				Pattern: reply[1].(string),
   408  				Channel: reply[2].(string),
   409  				Payload: reply[3].(string),
   410  			}, nil
   411  		case "pong":
   412  			return &Pong{
   413  				Payload: reply[1].(string),
   414  			}, nil
   415  		default:
   416  			return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
   417  		}
   418  	default:
   419  		return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
   420  	}
   421  }
   422  
   423  // ReceiveTimeout acts like Receive but returns an error if message
   424  // is not received in time. This is low-level API and in most cases
   425  // Channel should be used instead.
   426  func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
   427  	if c.cmd == nil {
   428  		c.cmd = NewCmd(ctx)
   429  	}
   430  
   431  	// Don't hold the lock to allow subscriptions and pings.
   432  
   433  	cn, err := c.connWithLock(ctx)
   434  	if err != nil {
   435  		return nil, err
   436  	}
   437  
   438  	err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
   439  		return c.cmd.readReply(rd)
   440  	})
   441  
   442  	c.releaseConnWithLock(ctx, cn, err, timeout > 0)
   443  
   444  	if err != nil {
   445  		return nil, err
   446  	}
   447  
   448  	return c.newMessage(c.cmd.Val())
   449  }
   450  
   451  // Receive returns a message as a Subscription, Message, Pong or error.
   452  // See PubSub example for details. This is low-level API and in most cases
   453  // Channel should be used instead.
   454  func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
   455  	return c.ReceiveTimeout(ctx, 0)
   456  }
   457  
   458  // ReceiveMessage returns a Message or error ignoring Subscription and Pong
   459  // messages. This is low-level API and in most cases Channel should be used
   460  // instead.
   461  func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
   462  	for {
   463  		msg, err := c.Receive(ctx)
   464  		if err != nil {
   465  			return nil, err
   466  		}
   467  
   468  		switch msg := msg.(type) {
   469  		case *Subscription:
   470  			// Ignore.
   471  		case *Pong:
   472  			// Ignore.
   473  		case *Message:
   474  			return msg, nil
   475  		default:
   476  			err := fmt.Errorf("redis: unknown message: %T", msg)
   477  			return nil, err
   478  		}
   479  	}
   480  }
   481  
   482  func (c *PubSub) getContext() context.Context {
   483  	if c.cmd != nil {
   484  		return c.cmd.ctx
   485  	}
   486  	return context.Background()
   487  }
   488  
   489  //------------------------------------------------------------------------------
   490  
   491  // Channel returns a Go channel for concurrently receiving messages.
   492  // The channel is closed together with the PubSub. If the Go channel
   493  // is blocked full for 1 minute the message is dropped.
   494  // Receive* APIs can not be used after channel is created.
   495  //
   496  // go-redis periodically sends ping messages to test connection health
   497  // and re-subscribes if ping can not received for 1 minute.
   498  func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
   499  	c.chOnce.Do(func() {
   500  		c.msgCh = newChannel(c, opts...)
   501  		c.msgCh.initMsgChan()
   502  	})
   503  	if c.msgCh == nil {
   504  		err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
   505  		panic(err)
   506  	}
   507  	return c.msgCh.msgCh
   508  }
   509  
   510  // ChannelSize is like Channel, but creates a Go channel
   511  // with specified buffer size.
   512  //
   513  // Deprecated: use Channel(WithChannelSize(size)), remove in v9.
   514  func (c *PubSub) ChannelSize(size int) <-chan *Message {
   515  	return c.Channel(WithChannelSize(size))
   516  }
   517  
   518  // ChannelWithSubscriptions is like Channel, but message type can be either
   519  // *Subscription or *Message. Subscription messages can be used to detect
   520  // reconnections.
   521  //
   522  // ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
   523  func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interface{} {
   524  	c.chOnce.Do(func() {
   525  		c.allCh = newChannel(c, opts...)
   526  		c.allCh.initAllChan()
   527  	})
   528  	if c.allCh == nil {
   529  		err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
   530  		panic(err)
   531  	}
   532  	return c.allCh.allCh
   533  }
   534  
   535  type ChannelOption func(c *channel)
   536  
   537  // WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
   538  //
   539  // The default is 100 messages.
   540  func WithChannelSize(size int) ChannelOption {
   541  	return func(c *channel) {
   542  		c.chanSize = size
   543  	}
   544  }
   545  
   546  // WithChannelHealthCheckInterval specifies the health check interval.
   547  // PubSub will ping Redis Server if it does not receive any messages within the interval.
   548  // To disable health check, use zero interval.
   549  //
   550  // The default is 3 seconds.
   551  func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
   552  	return func(c *channel) {
   553  		c.checkInterval = d
   554  	}
   555  }
   556  
   557  // WithChannelSendTimeout specifies the channel send timeout after which
   558  // the message is dropped.
   559  //
   560  // The default is 60 seconds.
   561  func WithChannelSendTimeout(d time.Duration) ChannelOption {
   562  	return func(c *channel) {
   563  		c.chanSendTimeout = d
   564  	}
   565  }
   566  
   567  type channel struct {
   568  	pubSub *PubSub
   569  
   570  	msgCh chan *Message
   571  	allCh chan interface{}
   572  	ping  chan struct{}
   573  
   574  	chanSize        int
   575  	chanSendTimeout time.Duration
   576  	checkInterval   time.Duration
   577  }
   578  
   579  func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
   580  	c := &channel{
   581  		pubSub: pubSub,
   582  
   583  		chanSize:        100,
   584  		chanSendTimeout: time.Minute,
   585  		checkInterval:   3 * time.Second,
   586  	}
   587  	for _, opt := range opts {
   588  		opt(c)
   589  	}
   590  	if c.checkInterval > 0 {
   591  		c.initHealthCheck()
   592  	}
   593  	return c
   594  }
   595  
   596  func (c *channel) initHealthCheck() {
   597  	ctx := context.TODO()
   598  	c.ping = make(chan struct{}, 1)
   599  
   600  	go func() {
   601  		timer := time.NewTimer(time.Minute)
   602  		timer.Stop()
   603  
   604  		for {
   605  			timer.Reset(c.checkInterval)
   606  			select {
   607  			case <-c.ping:
   608  				if !timer.Stop() {
   609  					<-timer.C
   610  				}
   611  			case <-timer.C:
   612  				if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
   613  					c.pubSub.mu.Lock()
   614  					c.pubSub.reconnect(ctx, pingErr)
   615  					c.pubSub.mu.Unlock()
   616  				}
   617  			case <-c.pubSub.exit:
   618  				return
   619  			}
   620  		}
   621  	}()
   622  }
   623  
   624  // initMsgChan must be in sync with initAllChan.
   625  func (c *channel) initMsgChan() {
   626  	ctx := context.TODO()
   627  	c.msgCh = make(chan *Message, c.chanSize)
   628  
   629  	go func() {
   630  		timer := time.NewTimer(time.Minute)
   631  		timer.Stop()
   632  
   633  		var errCount int
   634  		for {
   635  			msg, err := c.pubSub.Receive(ctx)
   636  			if err != nil {
   637  				if err == pool.ErrClosed {
   638  					close(c.msgCh)
   639  					return
   640  				}
   641  				if errCount > 0 {
   642  					time.Sleep(100 * time.Millisecond)
   643  				}
   644  				errCount++
   645  				continue
   646  			}
   647  
   648  			errCount = 0
   649  
   650  			// Any message is as good as a ping.
   651  			select {
   652  			case c.ping <- struct{}{}:
   653  			default:
   654  			}
   655  
   656  			switch msg := msg.(type) {
   657  			case *Subscription:
   658  				// Ignore.
   659  			case *Pong:
   660  				// Ignore.
   661  			case *Message:
   662  				timer.Reset(c.chanSendTimeout)
   663  				select {
   664  				case c.msgCh <- msg:
   665  					if !timer.Stop() {
   666  						<-timer.C
   667  					}
   668  				case <-timer.C:
   669  					internal.Logger.Printf(
   670  						ctx, "redis: %s channel is full for %s (message is dropped)",
   671  						c, c.chanSendTimeout)
   672  				}
   673  			default:
   674  				internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
   675  			}
   676  		}
   677  	}()
   678  }
   679  
   680  // initAllChan must be in sync with initMsgChan.
   681  func (c *channel) initAllChan() {
   682  	ctx := context.TODO()
   683  	c.allCh = make(chan interface{}, c.chanSize)
   684  
   685  	go func() {
   686  		timer := time.NewTimer(time.Minute)
   687  		timer.Stop()
   688  
   689  		var errCount int
   690  		for {
   691  			msg, err := c.pubSub.Receive(ctx)
   692  			if err != nil {
   693  				if err == pool.ErrClosed {
   694  					close(c.allCh)
   695  					return
   696  				}
   697  				if errCount > 0 {
   698  					time.Sleep(100 * time.Millisecond)
   699  				}
   700  				errCount++
   701  				continue
   702  			}
   703  
   704  			errCount = 0
   705  
   706  			// Any message is as good as a ping.
   707  			select {
   708  			case c.ping <- struct{}{}:
   709  			default:
   710  			}
   711  
   712  			switch msg := msg.(type) {
   713  			case *Pong:
   714  				// Ignore.
   715  			case *Subscription, *Message:
   716  				timer.Reset(c.chanSendTimeout)
   717  				select {
   718  				case c.allCh <- msg:
   719  					if !timer.Stop() {
   720  						<-timer.C
   721  					}
   722  				case <-timer.C:
   723  					internal.Logger.Printf(
   724  						ctx, "redis: %s channel is full for %s (message is dropped)",
   725  						c, c.chanSendTimeout)
   726  				}
   727  			default:
   728  				internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
   729  			}
   730  		}
   731  	}()
   732  }
   733  

View as plain text