1 package auth
2
3 import (
4 "errors"
5 "strings"
6 "sync"
7 "testing"
8 "time"
9 )
10
11 type mockStreamingProvider struct {
12 credentials Credentials
13 err error
14 updates chan Credentials
15 }
16
17 func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
18 return &mockStreamingProvider{
19 credentials: initialCreds,
20 updates: make(chan Credentials, 10),
21 }
22 }
23
24 func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
25 if m.err != nil {
26 return nil, nil, m.err
27 }
28
29
30 listener.OnNext(m.credentials)
31
32
33 go func() {
34 for creds := range m.updates {
35 listener.OnNext(creds)
36 }
37 }()
38
39 return m.credentials, func() error {
40 close(m.updates)
41 return nil
42 }, nil
43 }
44
45 func TestStreamingCredentialsProvider(t *testing.T) {
46 t.Run("successful subscription", func(t *testing.T) {
47 initialCreds := NewBasicCredentials("user1", "pass1")
48 provider := newMockStreamingProvider(initialCreds)
49
50 var receivedCreds []Credentials
51 var receivedErrors []error
52 var mu sync.Mutex
53
54 listener := NewReAuthCredentialsListener(
55 func(creds Credentials) error {
56 mu.Lock()
57 receivedCreds = append(receivedCreds, creds)
58 mu.Unlock()
59 return nil
60 },
61 func(err error) {
62 receivedErrors = append(receivedErrors, err)
63 },
64 )
65
66 creds, cancel, err := provider.Subscribe(listener)
67 if err != nil {
68 t.Fatalf("unexpected error: %v", err)
69 }
70 if cancel == nil {
71 t.Fatal("expected cancel function to be non-nil")
72 }
73 if creds != initialCreds {
74 t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
75 }
76 if len(receivedCreds) != 1 {
77 t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
78 }
79 if receivedCreds[0] != initialCreds {
80 t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
81 }
82 if len(receivedErrors) != 0 {
83 t.Fatalf("expected no errors, got %d", len(receivedErrors))
84 }
85
86
87 newCreds := NewBasicCredentials("user2", "pass2")
88 provider.updates <- newCreds
89
90
91 time.Sleep(100 * time.Millisecond)
92 mu.Lock()
93 if len(receivedCreds) != 2 {
94 t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
95 }
96 if receivedCreds[1] != newCreds {
97 t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
98 }
99 mu.Unlock()
100
101
102 if err := cancel(); err != nil {
103 t.Fatalf("unexpected error cancelling subscription: %v", err)
104 }
105 })
106
107 t.Run("subscription error", func(t *testing.T) {
108 provider := &mockStreamingProvider{
109 err: errors.New("subscription failed"),
110 }
111
112 var receivedCreds []Credentials
113 var receivedErrors []error
114
115 listener := NewReAuthCredentialsListener(
116 func(creds Credentials) error {
117 receivedCreds = append(receivedCreds, creds)
118 return nil
119 },
120 func(err error) {
121 receivedErrors = append(receivedErrors, err)
122 },
123 )
124
125 creds, cancel, err := provider.Subscribe(listener)
126 if err == nil {
127 t.Fatal("expected error, got nil")
128 }
129 if cancel != nil {
130 t.Fatal("expected cancel function to be nil")
131 }
132 if creds != nil {
133 t.Fatalf("expected nil credentials, got %v", creds)
134 }
135 if len(receivedCreds) != 0 {
136 t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
137 }
138 if len(receivedErrors) != 0 {
139 t.Fatalf("expected no errors, got %d", len(receivedErrors))
140 }
141 })
142
143 t.Run("re-auth error", func(t *testing.T) {
144 initialCreds := NewBasicCredentials("user1", "pass1")
145 provider := newMockStreamingProvider(initialCreds)
146
147 reauthErr := errors.New("re-auth failed")
148 var receivedErrors []error
149
150 listener := NewReAuthCredentialsListener(
151 func(creds Credentials) error {
152 return reauthErr
153 },
154 func(err error) {
155 receivedErrors = append(receivedErrors, err)
156 },
157 )
158
159 creds, cancel, err := provider.Subscribe(listener)
160 if err != nil {
161 t.Fatalf("unexpected error: %v", err)
162 }
163 if cancel == nil {
164 t.Fatal("expected cancel function to be non-nil")
165 }
166 if creds != initialCreds {
167 t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
168 }
169 if len(receivedErrors) != 1 {
170 t.Fatalf("expected 1 error, got %d", len(receivedErrors))
171 }
172 if receivedErrors[0] != reauthErr {
173 t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
174 }
175
176 if err := cancel(); err != nil {
177 t.Fatalf("unexpected error cancelling subscription: %v", err)
178 }
179 })
180 }
181
182 func TestBasicCredentials(t *testing.T) {
183 tests := []struct {
184 name string
185 username string
186 password string
187 expectedUser string
188 expectedPass string
189 expectedRaw string
190 }{
191 {
192 name: "basic auth",
193 username: "user1",
194 password: "pass1",
195 expectedUser: "user1",
196 expectedPass: "pass1",
197 expectedRaw: "user1:pass1",
198 },
199 {
200 name: "empty username",
201 username: "",
202 password: "pass1",
203 expectedUser: "",
204 expectedPass: "pass1",
205 expectedRaw: ":pass1",
206 },
207 {
208 name: "empty password",
209 username: "user1",
210 password: "",
211 expectedUser: "user1",
212 expectedPass: "",
213 expectedRaw: "user1:",
214 },
215 {
216 name: "both username and password empty",
217 username: "",
218 password: "",
219 expectedUser: "",
220 expectedPass: "",
221 expectedRaw: ":",
222 },
223 {
224 name: "special characters",
225 username: "user:1",
226 password: "pa:ss@!#",
227 expectedUser: "user:1",
228 expectedPass: "pa:ss@!#",
229 expectedRaw: "user:1:pa:ss@!#",
230 },
231 {
232 name: "unicode characters",
233 username: "ユーザー",
234 password: "密碼123",
235 expectedUser: "ユーザー",
236 expectedPass: "密碼123",
237 expectedRaw: "ユーザー:密碼123",
238 },
239 {
240 name: "long credentials",
241 username: strings.Repeat("u", 1000),
242 password: strings.Repeat("p", 1000),
243 expectedUser: strings.Repeat("u", 1000),
244 expectedPass: strings.Repeat("p", 1000),
245 expectedRaw: strings.Repeat("u", 1000) + ":" + strings.Repeat("p", 1000),
246 },
247 }
248
249 for _, tt := range tests {
250 t.Run(tt.name, func(t *testing.T) {
251 creds := NewBasicCredentials(tt.username, tt.password)
252
253 user, pass := creds.BasicAuth()
254 if user != tt.expectedUser {
255 t.Errorf("BasicAuth() username = %q; want %q", user, tt.expectedUser)
256 }
257 if pass != tt.expectedPass {
258 t.Errorf("BasicAuth() password = %q; want %q", pass, tt.expectedPass)
259 }
260
261 raw := creds.RawCredentials()
262 if raw != tt.expectedRaw {
263 t.Errorf("RawCredentials() = %q; want %q", raw, tt.expectedRaw)
264 }
265 })
266 }
267 }
268
269 func TestReAuthCredentialsListener(t *testing.T) {
270 t.Run("successful re-auth", func(t *testing.T) {
271 var reAuthCalled bool
272 var onErrCalled bool
273 var receivedCreds Credentials
274
275 listener := NewReAuthCredentialsListener(
276 func(creds Credentials) error {
277 reAuthCalled = true
278 receivedCreds = creds
279 return nil
280 },
281 func(err error) {
282 onErrCalled = true
283 },
284 )
285
286 creds := NewBasicCredentials("user1", "pass1")
287 listener.OnNext(creds)
288
289 if !reAuthCalled {
290 t.Fatal("expected reAuth to be called")
291 }
292 if onErrCalled {
293 t.Fatal("expected onErr not to be called")
294 }
295 if receivedCreds != creds {
296 t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
297 }
298 })
299
300 t.Run("re-auth error", func(t *testing.T) {
301 var reAuthCalled bool
302 var onErrCalled bool
303 var receivedErr error
304 expectedErr := errors.New("re-auth failed")
305
306 listener := NewReAuthCredentialsListener(
307 func(creds Credentials) error {
308 reAuthCalled = true
309 return expectedErr
310 },
311 func(err error) {
312 onErrCalled = true
313 receivedErr = err
314 },
315 )
316
317 creds := NewBasicCredentials("user1", "pass1")
318 listener.OnNext(creds)
319
320 if !reAuthCalled {
321 t.Fatal("expected reAuth to be called")
322 }
323 if !onErrCalled {
324 t.Fatal("expected onErr to be called")
325 }
326 if receivedErr != expectedErr {
327 t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
328 }
329 })
330
331 t.Run("on error", func(t *testing.T) {
332 var onErrCalled bool
333 var receivedErr error
334 expectedErr := errors.New("provider error")
335
336 listener := NewReAuthCredentialsListener(
337 func(creds Credentials) error {
338 return nil
339 },
340 func(err error) {
341 onErrCalled = true
342 receivedErr = err
343 },
344 )
345
346 listener.OnError(expectedErr)
347
348 if !onErrCalled {
349 t.Fatal("expected onErr to be called")
350 }
351 if receivedErr != expectedErr {
352 t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
353 }
354 })
355
356 t.Run("nil callbacks", func(t *testing.T) {
357 listener := NewReAuthCredentialsListener(nil, nil)
358
359
360 listener.OnNext(NewBasicCredentials("user1", "pass1"))
361 listener.OnError(errors.New("test error"))
362 })
363 }
364
View as plain text