...
1 package proto
2
3 import (
4 "encoding"
5 "fmt"
6 "io"
7 "net"
8 "strconv"
9 "time"
10
11 "github.com/redis/go-redis/v9/internal/util"
12 )
13
// writer is the minimal sink a Writer encodes onto: it must accept whole
// byte slices, single bytes, and strings.
type writer interface {
	io.ByteWriter
	io.Writer

	WriteString(string) (int, error)
}
20
21 type Writer struct {
22 writer
23
24 lenBuf []byte
25 numBuf []byte
26 }
27
28 func NewWriter(wr writer) *Writer {
29 return &Writer{
30 writer: wr,
31
32 lenBuf: make([]byte, 64),
33 numBuf: make([]byte, 64),
34 }
35 }
36
37 func (w *Writer) WriteArgs(args []interface{}) error {
38 if err := w.WriteByte(RespArray); err != nil {
39 return err
40 }
41
42 if err := w.writeLen(len(args)); err != nil {
43 return err
44 }
45
46 for _, arg := range args {
47 if err := w.WriteArg(arg); err != nil {
48 return err
49 }
50 }
51
52 return nil
53 }
54
55 func (w *Writer) writeLen(n int) error {
56 w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
57 w.lenBuf = append(w.lenBuf, '\r', '\n')
58 _, err := w.Write(w.lenBuf)
59 return err
60 }
61
62 func (w *Writer) WriteArg(v interface{}) error {
63 switch v := v.(type) {
64 case nil:
65 return w.string("")
66 case string:
67 return w.string(v)
68 case *string:
69 if v == nil {
70 return w.string("")
71 }
72 return w.string(*v)
73 case []byte:
74 return w.bytes(v)
75 case int:
76 return w.int(int64(v))
77 case *int:
78 if v == nil {
79 return w.int(0)
80 }
81 return w.int(int64(*v))
82 case int8:
83 return w.int(int64(v))
84 case *int8:
85 if v == nil {
86 return w.int(0)
87 }
88 return w.int(int64(*v))
89 case int16:
90 return w.int(int64(v))
91 case *int16:
92 if v == nil {
93 return w.int(0)
94 }
95 return w.int(int64(*v))
96 case int32:
97 return w.int(int64(v))
98 case *int32:
99 if v == nil {
100 return w.int(0)
101 }
102 return w.int(int64(*v))
103 case int64:
104 return w.int(v)
105 case *int64:
106 if v == nil {
107 return w.int(0)
108 }
109 return w.int(*v)
110 case uint:
111 return w.uint(uint64(v))
112 case *uint:
113 if v == nil {
114 return w.uint(0)
115 }
116 return w.uint(uint64(*v))
117 case uint8:
118 return w.uint(uint64(v))
119 case *uint8:
120 if v == nil {
121 return w.string("")
122 }
123 return w.uint(uint64(*v))
124 case uint16:
125 return w.uint(uint64(v))
126 case *uint16:
127 if v == nil {
128 return w.uint(0)
129 }
130 return w.uint(uint64(*v))
131 case uint32:
132 return w.uint(uint64(v))
133 case *uint32:
134 if v == nil {
135 return w.uint(0)
136 }
137 return w.uint(uint64(*v))
138 case uint64:
139 return w.uint(v)
140 case *uint64:
141 if v == nil {
142 return w.uint(0)
143 }
144 return w.uint(*v)
145 case float32:
146 return w.float(float64(v))
147 case *float32:
148 if v == nil {
149 return w.float(0)
150 }
151 return w.float(float64(*v))
152 case float64:
153 return w.float(v)
154 case *float64:
155 if v == nil {
156 return w.float(0)
157 }
158 return w.float(*v)
159 case bool:
160 if v {
161 return w.int(1)
162 }
163 return w.int(0)
164 case *bool:
165 if v == nil {
166 return w.int(0)
167 }
168 if *v {
169 return w.int(1)
170 }
171 return w.int(0)
172 case time.Time:
173 w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
174 return w.bytes(w.numBuf)
175 case *time.Time:
176 if v == nil {
177 v = &time.Time{}
178 }
179 w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
180 return w.bytes(w.numBuf)
181 case time.Duration:
182 return w.int(v.Nanoseconds())
183 case *time.Duration:
184 if v == nil {
185 return w.int(0)
186 }
187 return w.int(v.Nanoseconds())
188 case encoding.BinaryMarshaler:
189 b, err := v.MarshalBinary()
190 if err != nil {
191 return err
192 }
193 return w.bytes(b)
194 case net.IP:
195 return w.bytes(v)
196 default:
197 return fmt.Errorf(
198 "redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
199 }
200 }
201
202 func (w *Writer) bytes(b []byte) error {
203 if err := w.WriteByte(RespString); err != nil {
204 return err
205 }
206
207 if err := w.writeLen(len(b)); err != nil {
208 return err
209 }
210
211 if _, err := w.Write(b); err != nil {
212 return err
213 }
214
215 return w.crlf()
216 }
217
218 func (w *Writer) string(s string) error {
219 return w.bytes(util.StringToBytes(s))
220 }
221
222 func (w *Writer) uint(n uint64) error {
223 w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10)
224 return w.bytes(w.numBuf)
225 }
226
227 func (w *Writer) int(n int64) error {
228 w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10)
229 return w.bytes(w.numBuf)
230 }
231
232 func (w *Writer) float(f float64) error {
233 w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
234 return w.bytes(w.numBuf)
235 }
236
237 func (w *Writer) crlf() error {
238 if err := w.WriteByte('\r'); err != nil {
239 return err
240 }
241 return w.WriteByte('\n')
242 }
243
View as plain text