1 package pool
2
3 import (
4 "context"
5 "errors"
6 "net"
7 "sync"
8 "sync/atomic"
9 "time"
10
11 "github.com/redis/go-redis/v9/internal"
12 )
13
var (
	// ErrClosed is returned by pool operations once the pool has been
	// closed, including by a second call to Close.
	ErrClosed = errors.New("redis: client is closed")

	// ErrPoolExhausted is returned when a new connection cannot be created
	// because MaxActiveConns pooled connections already exist.
	ErrPoolExhausted = errors.New("redis: connection pool exhausted")

	// ErrPoolTimeout is returned when no pool turn became free within
	// Options.PoolTimeout.
	ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
25
// timers recycles *time.Timer values used while waiting for a pool turn, so
// a blocked waiter does not allocate a fresh timer each time. Timers produced
// by New are already stopped; callers Reset them before use.
var timers = sync.Pool{
	New: func() any {
		tm := time.NewTimer(time.Hour)
		tm.Stop()
		return tm
	},
}
33
34
// Stats is a point-in-time snapshot of pool counters, as returned by
// ConnPool.Stats.
type Stats struct {
	Hits           uint32 // times Get handed out a healthy idle connection
	Misses         uint32 // times Get found no usable idle connection
	Timeouts       uint32 // times waiting for a pool turn hit PoolTimeout
	WaitCount      uint32 // times a caller blocked and then obtained a turn
	WaitDurationNs int64  // total nanoseconds spent blocked waiting for a turn

	TotalConns uint32 // connections currently tracked by the pool
	IdleConns  uint32 // idle connections currently in the pool
	StaleConns uint32 // incremented on each connection removal
}
46
47 type Pooler interface {
48 NewConn(context.Context) (*Conn, error)
49 CloseConn(*Conn) error
50
51 Get(context.Context) (*Conn, error)
52 Put(context.Context, *Conn)
53 Remove(context.Context, *Conn, error)
54
55 Len() int
56 IdleLen() int
57 Stats() *Stats
58
59 Close() error
60 }
61
// Options configures a ConnPool.
type Options struct {
	// Dialer opens a new network connection; it is used both for regular
	// dials and for the background reconnect probe.
	Dialer func(context.Context) (net.Conn, error)

	// PoolFIFO hands out the oldest idle connection first instead of the
	// most recently returned one.
	PoolFIFO bool

	// PoolSize bounds the number of concurrent turns (the turn queue's
	// capacity) and the number of connections counted as pooled. It is also
	// the dial-failure count after which dialing fails fast.
	PoolSize int

	// DialTimeout bounds background dials (idle pre-fill and reconnect probes).
	DialTimeout time.Duration

	// PoolTimeout is how long Get waits for a free turn before ErrPoolTimeout.
	PoolTimeout time.Duration

	// MinIdleConns is the idle-connection count the pool tries to maintain;
	// 0 disables pre-filling.
	MinIdleConns int

	// MaxIdleConns caps how many idle connections Put keeps; 0 means no cap.
	MaxIdleConns int

	// MaxActiveConns caps pooled connections before ErrPoolExhausted;
	// 0 means no cap.
	MaxActiveConns int

	// ConnMaxIdleTime discards connections unused for at least this long
	// when they are taken from the idle set; 0 disables the check.
	ConnMaxIdleTime time.Duration

	// ConnMaxLifetime discards connections at least this old when they are
	// taken from the idle set; 0 disables the check.
	ConnMaxLifetime time.Duration

	// ReadBufferSize and WriteBufferSize are passed to NewConnWithBufferSize
	// for every new connection.
	ReadBufferSize  int
	WriteBufferSize int
}
78
// lastDialErrorWrap boxes a dial error so it can be stored in an
// atomic.Value, which requires every stored value to share one concrete type.
type lastDialErrorWrap struct {
	err error // the most recent dial error; may be nil
}
82
83 type ConnPool struct {
84 cfg *Options
85
86 dialErrorsNum uint32
87 lastDialError atomic.Value
88
89 queue chan struct{}
90
91 connsMu sync.Mutex
92 conns []*Conn
93 idleConns []*Conn
94
95 poolSize int
96 idleConnsLen int
97
98 stats Stats
99 waitDurationNs atomic.Int64
100
101 _closed uint32
102 }
103
104 var _ Pooler = (*ConnPool)(nil)
105
106 func NewConnPool(opt *Options) *ConnPool {
107 p := &ConnPool{
108 cfg: opt,
109
110 queue: make(chan struct{}, opt.PoolSize),
111 conns: make([]*Conn, 0, opt.PoolSize),
112 idleConns: make([]*Conn, 0, opt.PoolSize),
113 }
114
115 p.connsMu.Lock()
116 p.checkMinIdleConns()
117 p.connsMu.Unlock()
118
119 return p
120 }
121
// checkMinIdleConns starts background dials until the pool counts PoolSize
// pooled connections or the idle count reaches MinIdleConns. Every caller in
// this file holds connsMu, which guards poolSize and idleConnsLen.
// Each background dial holds a turn in p.queue for its duration. If no turn
// is free, the method returns immediately instead of blocking.
func (p *ConnPool) checkMinIdleConns() {
	if p.cfg.MinIdleConns == 0 {
		return
	}
	for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns {
		select {
		case p.queue <- struct{}{}:
			// Reserve the slot in both counters before dialing so the loop
			// condition accounts for dials that are still in flight.
			p.poolSize++
			p.idleConnsLen++

			go func() {
				// If addIdleConn panics, undo the reservation made above,
				// release the turn, and log instead of crashing the process.
				defer func() {
					if err := recover(); err != nil {
						p.connsMu.Lock()
						p.poolSize--
						p.idleConnsLen--
						p.connsMu.Unlock()

						p.freeTurn()
						internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err)
					}
				}()

				err := p.addIdleConn()
				// On ErrClosed the counters are left alone because Close
				// resets them to zero itself.
				if err != nil && err != ErrClosed {
					p.connsMu.Lock()
					p.poolSize--
					p.idleConnsLen--
					p.connsMu.Unlock()
				}

				p.freeTurn()
			}()
		default:
			// Every turn is in use; don't block (the caller holds connsMu).
			return
		}
	}
}
160
161 func (p *ConnPool) addIdleConn() error {
162 ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
163 defer cancel()
164
165 cn, err := p.dialConn(ctx, true)
166 if err != nil {
167 return err
168 }
169
170 p.connsMu.Lock()
171 defer p.connsMu.Unlock()
172
173
174 if p.closed() {
175 _ = cn.Close()
176 return ErrClosed
177 }
178
179 p.conns = append(p.conns, cn)
180 p.idleConns = append(p.idleConns, cn)
181 return nil
182 }
183
184 func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
185 return p.newConn(ctx, false)
186 }
187
188 func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
189 if p.closed() {
190 return nil, ErrClosed
191 }
192
193 p.connsMu.Lock()
194 if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
195 p.connsMu.Unlock()
196 return nil, ErrPoolExhausted
197 }
198 p.connsMu.Unlock()
199
200 cn, err := p.dialConn(ctx, pooled)
201 if err != nil {
202 return nil, err
203 }
204
205 p.connsMu.Lock()
206 defer p.connsMu.Unlock()
207
208 if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns {
209 _ = cn.Close()
210 return nil, ErrPoolExhausted
211 }
212
213 p.conns = append(p.conns, cn)
214 if pooled {
215
216 if p.poolSize >= p.cfg.PoolSize {
217 cn.pooled = false
218 } else {
219 p.poolSize++
220 }
221 }
222
223 return cn, nil
224 }
225
226 func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
227 if p.closed() {
228 return nil, ErrClosed
229 }
230
231 if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.cfg.PoolSize) {
232 return nil, p.getLastDialError()
233 }
234
235 netConn, err := p.cfg.Dialer(ctx)
236 if err != nil {
237 p.setLastDialError(err)
238 if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) {
239 go p.tryDial()
240 }
241 return nil, err
242 }
243
244 cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize)
245 cn.pooled = pooled
246 return cn, nil
247 }
248
249 func (p *ConnPool) tryDial() {
250 for {
251 if p.closed() {
252 return
253 }
254
255 ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)
256
257 conn, err := p.cfg.Dialer(ctx)
258 if err != nil {
259 p.setLastDialError(err)
260 time.Sleep(time.Second)
261 cancel()
262 continue
263 }
264
265 atomic.StoreUint32(&p.dialErrorsNum, 0)
266 _ = conn.Close()
267 cancel()
268 return
269 }
270 }
271
272 func (p *ConnPool) setLastDialError(err error) {
273 p.lastDialError.Store(&lastDialErrorWrap{err: err})
274 }
275
276 func (p *ConnPool) getLastDialError() error {
277 err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
278 if err != nil {
279 return err.err
280 }
281 return nil
282 }
283
284
285 func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
286 if p.closed() {
287 return nil, ErrClosed
288 }
289
290 if err := p.waitTurn(ctx); err != nil {
291 return nil, err
292 }
293
294 for {
295 p.connsMu.Lock()
296 cn, err := p.popIdle()
297 p.connsMu.Unlock()
298
299 if err != nil {
300 p.freeTurn()
301 return nil, err
302 }
303
304 if cn == nil {
305 break
306 }
307
308 if !p.isHealthyConn(cn) {
309 _ = p.CloseConn(cn)
310 continue
311 }
312
313 atomic.AddUint32(&p.stats.Hits, 1)
314 return cn, nil
315 }
316
317 atomic.AddUint32(&p.stats.Misses, 1)
318
319 newcn, err := p.newConn(ctx, true)
320 if err != nil {
321 p.freeTurn()
322 return nil, err
323 }
324
325 return newcn, nil
326 }
327
328 func (p *ConnPool) waitTurn(ctx context.Context) error {
329 select {
330 case <-ctx.Done():
331 return ctx.Err()
332 default:
333 }
334
335 select {
336 case p.queue <- struct{}{}:
337 return nil
338 default:
339 }
340
341 start := time.Now()
342 timer := timers.Get().(*time.Timer)
343 defer timers.Put(timer)
344 timer.Reset(p.cfg.PoolTimeout)
345
346 select {
347 case <-ctx.Done():
348 if !timer.Stop() {
349 <-timer.C
350 }
351 return ctx.Err()
352 case p.queue <- struct{}{}:
353 p.waitDurationNs.Add(time.Since(start).Nanoseconds())
354 atomic.AddUint32(&p.stats.WaitCount, 1)
355 if !timer.Stop() {
356 <-timer.C
357 }
358 return nil
359 case <-timer.C:
360 atomic.AddUint32(&p.stats.Timeouts, 1)
361 return ErrPoolTimeout
362 }
363 }
364
365 func (p *ConnPool) freeTurn() {
366 <-p.queue
367 }
368
369 func (p *ConnPool) popIdle() (*Conn, error) {
370 if p.closed() {
371 return nil, ErrClosed
372 }
373 n := len(p.idleConns)
374 if n == 0 {
375 return nil, nil
376 }
377
378 var cn *Conn
379 if p.cfg.PoolFIFO {
380 cn = p.idleConns[0]
381 copy(p.idleConns, p.idleConns[1:])
382 p.idleConns = p.idleConns[:n-1]
383 } else {
384 idx := n - 1
385 cn = p.idleConns[idx]
386 p.idleConns = p.idleConns[:idx]
387 }
388 p.idleConnsLen--
389 p.checkMinIdleConns()
390 return cn, nil
391 }
392
393 func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
394 if cn.rd.Buffered() > 0 {
395 internal.Logger.Printf(ctx, "Conn has unread data")
396 p.Remove(ctx, cn, BadConnError{})
397 return
398 }
399
400 if !cn.pooled {
401 p.Remove(ctx, cn, nil)
402 return
403 }
404
405 var shouldCloseConn bool
406
407 p.connsMu.Lock()
408
409 if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns {
410 p.idleConns = append(p.idleConns, cn)
411 p.idleConnsLen++
412 } else {
413 p.removeConn(cn)
414 shouldCloseConn = true
415 }
416
417 p.connsMu.Unlock()
418
419 p.freeTurn()
420
421 if shouldCloseConn {
422 _ = p.closeConn(cn)
423 }
424 }
425
426 func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) {
427 p.removeConnWithLock(cn)
428 p.freeTurn()
429 _ = p.closeConn(cn)
430 }
431
432 func (p *ConnPool) CloseConn(cn *Conn) error {
433 p.removeConnWithLock(cn)
434 return p.closeConn(cn)
435 }
436
437 func (p *ConnPool) removeConnWithLock(cn *Conn) {
438 p.connsMu.Lock()
439 defer p.connsMu.Unlock()
440 p.removeConn(cn)
441 }
442
443 func (p *ConnPool) removeConn(cn *Conn) {
444 for i, c := range p.conns {
445 if c == cn {
446 p.conns = append(p.conns[:i], p.conns[i+1:]...)
447 if cn.pooled {
448 p.poolSize--
449 p.checkMinIdleConns()
450 }
451 break
452 }
453 }
454 atomic.AddUint32(&p.stats.StaleConns, 1)
455 }
456
457 func (p *ConnPool) closeConn(cn *Conn) error {
458 return cn.Close()
459 }
460
461
462 func (p *ConnPool) Len() int {
463 p.connsMu.Lock()
464 n := len(p.conns)
465 p.connsMu.Unlock()
466 return n
467 }
468
469
470 func (p *ConnPool) IdleLen() int {
471 p.connsMu.Lock()
472 n := p.idleConnsLen
473 p.connsMu.Unlock()
474 return n
475 }
476
477 func (p *ConnPool) Stats() *Stats {
478 return &Stats{
479 Hits: atomic.LoadUint32(&p.stats.Hits),
480 Misses: atomic.LoadUint32(&p.stats.Misses),
481 Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
482 WaitCount: atomic.LoadUint32(&p.stats.WaitCount),
483 WaitDurationNs: p.waitDurationNs.Load(),
484
485 TotalConns: uint32(p.Len()),
486 IdleConns: uint32(p.IdleLen()),
487 StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
488 }
489 }
490
491 func (p *ConnPool) closed() bool {
492 return atomic.LoadUint32(&p._closed) == 1
493 }
494
495 func (p *ConnPool) Filter(fn func(*Conn) bool) error {
496 p.connsMu.Lock()
497 defer p.connsMu.Unlock()
498
499 var firstErr error
500 for _, cn := range p.conns {
501 if fn(cn) {
502 if err := p.closeConn(cn); err != nil && firstErr == nil {
503 firstErr = err
504 }
505 }
506 }
507 return firstErr
508 }
509
510 func (p *ConnPool) Close() error {
511 if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
512 return ErrClosed
513 }
514
515 var firstErr error
516 p.connsMu.Lock()
517 for _, cn := range p.conns {
518 if err := p.closeConn(cn); err != nil && firstErr == nil {
519 firstErr = err
520 }
521 }
522 p.conns = nil
523 p.poolSize = 0
524 p.idleConns = nil
525 p.idleConnsLen = 0
526 p.connsMu.Unlock()
527
528 return firstErr
529 }
530
531 func (p *ConnPool) isHealthyConn(cn *Conn) bool {
532 now := time.Now()
533
534 if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
535 return false
536 }
537 if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime {
538 return false
539 }
540
541 if connCheck(cn.netConn) != nil {
542 return false
543 }
544
545 cn.SetUsedAt(now)
546 return true
547 }
548
View as plain text