1 package redis
2
3 import (
4 "context"
5 "fmt"
6 "strings"
7 "sync"
8 "time"
9
10 "github.com/redis/go-redis/v9/internal"
11 "github.com/redis/go-redis/v9/internal/pool"
12 "github.com/redis/go-redis/v9/internal/proto"
13 )
14
15
16
17
18
19
20
// PubSub implements Pub/Sub commands on a single dedicated connection.
//
// The connection is dialed lazily (see conn) and replaced when an operation
// reports it as bad (see releaseConn). Every freshly dialed connection replays
// the subscriptions recorded in channels, patterns and schannels. The methods
// in this file take mu before touching cn, closed or the subscription sets.
type PubSub struct {
	opt *Options

	// newConn dials a new connection; it receives the recorded channel names
	// plus any being added (presumably so a cluster client can pick the node —
	// TODO confirm against the constructor).
	newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
	// closeConn releases a connection obtained from newConn.
	closeConn func(*pool.Conn) error

	mu        sync.Mutex
	cn        *pool.Conn          // current connection; nil before the first dial or after it was discarded
	channels  map[string]struct{} // channels from Subscribe, replayed with "subscribe"
	patterns  map[string]struct{} // patterns from PSubscribe, replayed with "psubscribe"
	schannels map[string]struct{} // shard channels from SSubscribe, replayed with "ssubscribe"

	closed bool          // set by Close; conn refuses to dial afterwards
	exit   chan struct{} // closed by Close; stops the health-check goroutine

	cmd *Cmd // reply buffer reused by ReceiveTimeout; created on first use

	chOnce sync.Once // lets only the first of Channel / ChannelWithSubscriptions initialize
	msgCh  *channel  // set by Channel
	allCh  *channel  // set by ChannelWithSubscriptions
}
42
43 func (c *PubSub) init() {
44 c.exit = make(chan struct{})
45 }
46
47 func (c *PubSub) String() string {
48 c.mu.Lock()
49 defer c.mu.Unlock()
50
51 channels := mapKeys(c.channels)
52 channels = append(channels, mapKeys(c.patterns)...)
53 channels = append(channels, mapKeys(c.schannels)...)
54 return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
55 }
56
57 func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
58 c.mu.Lock()
59 cn, err := c.conn(ctx, nil)
60 c.mu.Unlock()
61 return cn, err
62 }
63
64 func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
65 if c.closed {
66 return nil, pool.ErrClosed
67 }
68 if c.cn != nil {
69 return c.cn, nil
70 }
71
72 channels := mapKeys(c.channels)
73 channels = append(channels, newChannels...)
74
75 cn, err := c.newConn(ctx, channels)
76 if err != nil {
77 return nil, err
78 }
79
80 if err := c.resubscribe(ctx, cn); err != nil {
81 _ = c.closeConn(cn)
82 return nil, err
83 }
84
85 c.cn = cn
86 return cn, nil
87 }
88
89 func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
90 return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
91 return writeCmd(wr, cmd)
92 })
93 }
94
95 func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
96 var firstErr error
97
98 if len(c.channels) > 0 {
99 firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
100 }
101
102 if len(c.patterns) > 0 {
103 err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
104 if err != nil && firstErr == nil {
105 firstErr = err
106 }
107 }
108
109 if len(c.schannels) > 0 {
110 err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels))
111 if err != nil && firstErr == nil {
112 firstErr = err
113 }
114 }
115
116 return firstErr
117 }
118
// mapKeys returns the keys of m in map iteration order (i.e. unordered). The
// result is always non-nil, with length and capacity equal to len(m).
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
128
129 func (c *PubSub) _subscribe(
130 ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
131 ) error {
132 args := make([]interface{}, 0, 1+len(channels))
133 args = append(args, redisCmd)
134 for _, channel := range channels {
135 args = append(args, channel)
136 }
137 cmd := NewSliceCmd(ctx, args...)
138 return c.writeCmd(ctx, cn, cmd)
139 }
140
141 func (c *PubSub) releaseConnWithLock(
142 ctx context.Context,
143 cn *pool.Conn,
144 err error,
145 allowTimeout bool,
146 ) {
147 c.mu.Lock()
148 c.releaseConn(ctx, cn, err, allowTimeout)
149 c.mu.Unlock()
150 }
151
152 func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
153 if c.cn != cn {
154 return
155 }
156 if isBadConn(err, allowTimeout, c.opt.Addr) {
157 c.reconnect(ctx, err)
158 }
159 }
160
// reconnect discards the current connection (logging reason unless the PubSub
// is closed) and immediately tries to dial a replacement. conn replays all
// recorded subscriptions on the new connection. Must be called with c.mu held;
// every caller in this file holds it.
func (c *PubSub) reconnect(ctx context.Context, reason error) {
	// The old connection is thrown away regardless, so a close error is not actionable.
	_ = c.closeTheCn(reason)
	// Best-effort redial: on failure c.cn stays nil and the next conn call retries.
	_, _ = c.conn(ctx, nil)
}
165
166 func (c *PubSub) closeTheCn(reason error) error {
167 if c.cn == nil {
168 return nil
169 }
170 if !c.closed {
171 internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
172 }
173 err := c.closeConn(c.cn)
174 c.cn = nil
175 return err
176 }
177
178 func (c *PubSub) Close() error {
179 c.mu.Lock()
180 defer c.mu.Unlock()
181
182 if c.closed {
183 return pool.ErrClosed
184 }
185 c.closed = true
186 close(c.exit)
187
188 return c.closeTheCn(pool.ErrClosed)
189 }
190
191
192
193 func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
194 c.mu.Lock()
195 defer c.mu.Unlock()
196
197 err := c.subscribe(ctx, "subscribe", channels...)
198 if c.channels == nil {
199 c.channels = make(map[string]struct{})
200 }
201 for _, s := range channels {
202 c.channels[s] = struct{}{}
203 }
204 return err
205 }
206
207
208
209 func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
210 c.mu.Lock()
211 defer c.mu.Unlock()
212
213 err := c.subscribe(ctx, "psubscribe", patterns...)
214 if c.patterns == nil {
215 c.patterns = make(map[string]struct{})
216 }
217 for _, s := range patterns {
218 c.patterns[s] = struct{}{}
219 }
220 return err
221 }
222
223
224 func (c *PubSub) SSubscribe(ctx context.Context, channels ...string) error {
225 c.mu.Lock()
226 defer c.mu.Unlock()
227
228 err := c.subscribe(ctx, "ssubscribe", channels...)
229 if c.schannels == nil {
230 c.schannels = make(map[string]struct{})
231 }
232 for _, s := range channels {
233 c.schannels[s] = struct{}{}
234 }
235 return err
236 }
237
238
239
240 func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
241 c.mu.Lock()
242 defer c.mu.Unlock()
243
244 if len(channels) > 0 {
245 for _, channel := range channels {
246 delete(c.channels, channel)
247 }
248 } else {
249
250 for channel := range c.channels {
251 delete(c.channels, channel)
252 }
253 }
254
255 err := c.subscribe(ctx, "unsubscribe", channels...)
256 return err
257 }
258
259
260
261 func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
262 c.mu.Lock()
263 defer c.mu.Unlock()
264
265 if len(patterns) > 0 {
266 for _, pattern := range patterns {
267 delete(c.patterns, pattern)
268 }
269 } else {
270
271 for pattern := range c.patterns {
272 delete(c.patterns, pattern)
273 }
274 }
275
276 err := c.subscribe(ctx, "punsubscribe", patterns...)
277 return err
278 }
279
280
281
282 func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error {
283 c.mu.Lock()
284 defer c.mu.Unlock()
285
286 if len(channels) > 0 {
287 for _, channel := range channels {
288 delete(c.schannels, channel)
289 }
290 } else {
291
292 for channel := range c.schannels {
293 delete(c.schannels, channel)
294 }
295 }
296
297 err := c.subscribe(ctx, "sunsubscribe", channels...)
298 return err
299 }
300
301 func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
302 cn, err := c.conn(ctx, channels)
303 if err != nil {
304 return err
305 }
306
307 err = c._subscribe(ctx, cn, redisCmd, channels)
308 c.releaseConn(ctx, cn, err, false)
309 return err
310 }
311
312 func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
313 args := []interface{}{"ping"}
314 if len(payload) == 1 {
315 args = append(args, payload[0])
316 }
317 cmd := NewCmd(ctx, args...)
318
319 c.mu.Lock()
320 defer c.mu.Unlock()
321
322 cn, err := c.conn(ctx, nil)
323 if err != nil {
324 return err
325 }
326
327 err = c.writeCmd(ctx, cn, cmd)
328 c.releaseConn(ctx, cn, err, false)
329 return err
330 }
331
332
// Subscription is produced for subscribe/unsubscribe confirmations (including
// the pattern and shard variants) read from the pub/sub connection.
type Subscription struct {
	// Kind is the reply kind: "subscribe", "unsubscribe", "psubscribe",
	// "punsubscribe", "ssubscribe" or "sunsubscribe".
	Kind string
	// Channel is the channel or pattern name; empty if the server sent none.
	Channel string
	// Count is the integer the server sent with the confirmation (presumably
	// the number of active subscriptions on the connection).
	Count int
}

// String formats the subscription as "kind: channel".
func (m *Subscription) String() string {
	return m.Kind + ": " + m.Channel
}
345
346
// Message is a message published to a channel the client is subscribed to.
type Message struct {
	// Channel is the channel the message was published to.
	Channel string
	// Pattern is the matching pattern for messages delivered through
	// PSubscribe ("pmessage"); empty otherwise.
	Pattern string
	// Payload is the message body when the server sent a single string.
	Payload string
	// PayloadSlice holds the body when the server sent an array of strings;
	// messages built by this package leave Payload empty in that case.
	PayloadSlice []string
}

// String returns "Message<channel: payload>". When the message carries only an
// array payload (PayloadSlice), that slice is shown instead, so the body is not
// silently omitted.
func (m *Message) String() string {
	if m.Payload == "" && len(m.PayloadSlice) > 0 {
		return fmt.Sprintf("Message<%s: %v>", m.Channel, m.PayloadSlice)
	}
	return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
}
357
358
// Pong is the reply to PubSub.Ping read from the pub/sub connection.
type Pong struct {
	// Payload is the echoed ping argument, empty for a bare PING.
	Payload string
}

// String returns "Pong", or "Pong<payload>" when a payload was echoed.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return "Pong<" + p.Payload + ">"
}
369
370 func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
371 switch reply := reply.(type) {
372 case string:
373 return &Pong{
374 Payload: reply,
375 }, nil
376 case []interface{}:
377 switch kind := reply[0].(string); kind {
378 case "subscribe", "unsubscribe", "psubscribe", "punsubscribe", "ssubscribe", "sunsubscribe":
379
380 channel, _ := reply[1].(string)
381 return &Subscription{
382 Kind: kind,
383 Channel: channel,
384 Count: int(reply[2].(int64)),
385 }, nil
386 case "message", "smessage":
387 switch payload := reply[2].(type) {
388 case string:
389 return &Message{
390 Channel: reply[1].(string),
391 Payload: payload,
392 }, nil
393 case []interface{}:
394 ss := make([]string, len(payload))
395 for i, s := range payload {
396 ss[i] = s.(string)
397 }
398 return &Message{
399 Channel: reply[1].(string),
400 PayloadSlice: ss,
401 }, nil
402 default:
403 return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
404 }
405 case "pmessage":
406 return &Message{
407 Pattern: reply[1].(string),
408 Channel: reply[2].(string),
409 Payload: reply[3].(string),
410 }, nil
411 case "pong":
412 return &Pong{
413 Payload: reply[1].(string),
414 }, nil
415 default:
416 return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
417 }
418 default:
419 return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
420 }
421 }
422
423
424
425
426 func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
427 if c.cmd == nil {
428 c.cmd = NewCmd(ctx)
429 }
430
431
432
433 cn, err := c.connWithLock(ctx)
434 if err != nil {
435 return nil, err
436 }
437
438 err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
439 return c.cmd.readReply(rd)
440 })
441
442 c.releaseConnWithLock(ctx, cn, err, timeout > 0)
443
444 if err != nil {
445 return nil, err
446 }
447
448 return c.newMessage(c.cmd.Val())
449 }
450
451
452
453
454 func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
455 return c.ReceiveTimeout(ctx, 0)
456 }
457
458
459
460
461 func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
462 for {
463 msg, err := c.Receive(ctx)
464 if err != nil {
465 return nil, err
466 }
467
468 switch msg := msg.(type) {
469 case *Subscription:
470
471 case *Pong:
472
473 case *Message:
474 return msg, nil
475 default:
476 err := fmt.Errorf("redis: unknown message: %T", msg)
477 return nil, err
478 }
479 }
480 }
481
482 func (c *PubSub) getContext() context.Context {
483 if c.cmd != nil {
484 return c.cmd.ctx
485 }
486 return context.Background()
487 }
488
489
490
491
492
493
494
495
496
497
498 func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
499 c.chOnce.Do(func() {
500 c.msgCh = newChannel(c, opts...)
501 c.msgCh.initMsgChan()
502 })
503 if c.msgCh == nil {
504 err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
505 panic(err)
506 }
507 return c.msgCh.msgCh
508 }
509
510
511
512
513
514 func (c *PubSub) ChannelSize(size int) <-chan *Message {
515 return c.Channel(WithChannelSize(size))
516 }
517
518
519
520
521
522
523 func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interface{} {
524 c.chOnce.Do(func() {
525 c.allCh = newChannel(c, opts...)
526 c.allCh.initAllChan()
527 })
528 if c.allCh == nil {
529 err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
530 panic(err)
531 }
532 return c.allCh.allCh
533 }
534
// ChannelOption configures the channel created by PubSub.Channel or
// PubSub.ChannelWithSubscriptions. Options are applied in order by newChannel,
// after the defaults are set.
type ChannelOption func(c *channel)
536
537
538
539
540 func WithChannelSize(size int) ChannelOption {
541 return func(c *channel) {
542 c.chanSize = size
543 }
544 }
545
546
547
548
549
550
551 func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
552 return func(c *channel) {
553 c.checkInterval = d
554 }
555 }
556
557
558
559
560
561 func WithChannelSendTimeout(d time.Duration) ChannelOption {
562 return func(c *channel) {
563 c.chanSendTimeout = d
564 }
565 }
566
// channel backs the Go channel returned by PubSub.Channel or
// PubSub.ChannelWithSubscriptions and owns the optional health-check goroutine.
type channel struct {
	pubSub *PubSub

	msgCh chan *Message    // created by initMsgChan (PubSub.Channel)
	allCh chan interface{} // created by initAllChan (PubSub.ChannelWithSubscriptions)
	ping  chan struct{}    // activity signal for the health checker; nil when the health check is disabled

	chanSize        int           // buffer size of msgCh/allCh
	chanSendTimeout time.Duration // wait for room in a full channel before dropping a message
	checkInterval   time.Duration // idle time before a health-check Ping; <= 0 disables it
}
578
579 func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
580 c := &channel{
581 pubSub: pubSub,
582
583 chanSize: 100,
584 chanSendTimeout: time.Minute,
585 checkInterval: 3 * time.Second,
586 }
587 for _, opt := range opts {
588 opt(c)
589 }
590 if c.checkInterval > 0 {
591 c.initHealthCheck()
592 }
593 return c
594 }
595
// initHealthCheck creates c.ping and starts a goroutine that pings the server
// whenever nothing has been received for c.checkInterval.
//
// The receive loops (initMsgChan / initAllChan) send on c.ping for every
// received value, which restarts the idle timer. When the timer fires, Ping
// writes a PING; the reply itself comes back through Receive as a *Pong. If
// Ping returns an error, the connection is discarded and re-dialed under the
// PubSub lock. The goroutine exits when PubSub.exit is closed by Close.
//
// NOTE(review): Ping's own releaseConn may already have reconnected on a
// bad-connection error, in which case this reconnects a second time — confirm
// whether that is intended.
func (c *channel) initHealthCheck() {
	ctx := context.TODO()
	// Capacity 1: the receive loop's non-blocking send can leave one pending
	// activity signal without waiting for this goroutine.
	c.ping = make(chan struct{}, 1)

	go func() {
		// Create the timer stopped; it is armed by Reset at the top of each iteration.
		timer := time.NewTimer(time.Minute)
		timer.Stop()

		for {
			timer.Reset(c.checkInterval)
			select {
			case <-c.ping:
				// Activity seen: stop the timer, draining it if it already
				// fired, so the next Reset starts clean.
				if !timer.Stop() {
					<-timer.C
				}
			case <-timer.C:
				if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
					c.pubSub.mu.Lock()
					c.pubSub.reconnect(ctx, pingErr)
					c.pubSub.mu.Unlock()
				}
			case <-c.pubSub.exit:
				return
			}
		}
	}()
}
623
624
625 func (c *channel) initMsgChan() {
626 ctx := context.TODO()
627 c.msgCh = make(chan *Message, c.chanSize)
628
629 go func() {
630 timer := time.NewTimer(time.Minute)
631 timer.Stop()
632
633 var errCount int
634 for {
635 msg, err := c.pubSub.Receive(ctx)
636 if err != nil {
637 if err == pool.ErrClosed {
638 close(c.msgCh)
639 return
640 }
641 if errCount > 0 {
642 time.Sleep(100 * time.Millisecond)
643 }
644 errCount++
645 continue
646 }
647
648 errCount = 0
649
650
651 select {
652 case c.ping <- struct{}{}:
653 default:
654 }
655
656 switch msg := msg.(type) {
657 case *Subscription:
658
659 case *Pong:
660
661 case *Message:
662 timer.Reset(c.chanSendTimeout)
663 select {
664 case c.msgCh <- msg:
665 if !timer.Stop() {
666 <-timer.C
667 }
668 case <-timer.C:
669 internal.Logger.Printf(
670 ctx, "redis: %s channel is full for %s (message is dropped)",
671 c, c.chanSendTimeout)
672 }
673 default:
674 internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
675 }
676 }
677 }()
678 }
679
680
681 func (c *channel) initAllChan() {
682 ctx := context.TODO()
683 c.allCh = make(chan interface{}, c.chanSize)
684
685 go func() {
686 timer := time.NewTimer(time.Minute)
687 timer.Stop()
688
689 var errCount int
690 for {
691 msg, err := c.pubSub.Receive(ctx)
692 if err != nil {
693 if err == pool.ErrClosed {
694 close(c.allCh)
695 return
696 }
697 if errCount > 0 {
698 time.Sleep(100 * time.Millisecond)
699 }
700 errCount++
701 continue
702 }
703
704 errCount = 0
705
706
707 select {
708 case c.ping <- struct{}{}:
709 default:
710 }
711
712 switch msg := msg.(type) {
713 case *Pong:
714
715 case *Subscription, *Message:
716 timer.Reset(c.chanSendTimeout)
717 select {
718 case c.allCh <- msg:
719 if !timer.Stop() {
720 <-timer.C
721 }
722 case <-timer.C:
723 internal.Logger.Printf(
724 ctx, "redis: %s channel is full for %s (message is dropped)",
725 c, c.chanSendTimeout)
726 }
727 default:
728 internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
729 }
730 }
731 }()
732 }
733
View as plain text