...

Source file src/github.com/redis/go-redis/v9/auth/auth_test.go

Documentation: github.com/redis/go-redis/v9/auth

     1  package auth
     2  
     3  import (
     4  	"errors"
     5  	"strings"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  )
    10  
    11  type mockStreamingProvider struct {
    12  	credentials Credentials
    13  	err         error
    14  	updates     chan Credentials
    15  }
    16  
    17  func newMockStreamingProvider(initialCreds Credentials) *mockStreamingProvider {
    18  	return &mockStreamingProvider{
    19  		credentials: initialCreds,
    20  		updates:     make(chan Credentials, 10),
    21  	}
    22  }
    23  
    24  func (m *mockStreamingProvider) Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) {
    25  	if m.err != nil {
    26  		return nil, nil, m.err
    27  	}
    28  
    29  	// Send initial credentials
    30  	listener.OnNext(m.credentials)
    31  
    32  	// Start goroutine to handle updates
    33  	go func() {
    34  		for creds := range m.updates {
    35  			listener.OnNext(creds)
    36  		}
    37  	}()
    38  
    39  	return m.credentials, func() error {
    40  		close(m.updates)
    41  		return nil
    42  	}, nil
    43  }
    44  
    45  func TestStreamingCredentialsProvider(t *testing.T) {
    46  	t.Run("successful subscription", func(t *testing.T) {
    47  		initialCreds := NewBasicCredentials("user1", "pass1")
    48  		provider := newMockStreamingProvider(initialCreds)
    49  
    50  		var receivedCreds []Credentials
    51  		var receivedErrors []error
    52  		var mu sync.Mutex
    53  
    54  		listener := NewReAuthCredentialsListener(
    55  			func(creds Credentials) error {
    56  				mu.Lock()
    57  				receivedCreds = append(receivedCreds, creds)
    58  				mu.Unlock()
    59  				return nil
    60  			},
    61  			func(err error) {
    62  				receivedErrors = append(receivedErrors, err)
    63  			},
    64  		)
    65  
    66  		creds, cancel, err := provider.Subscribe(listener)
    67  		if err != nil {
    68  			t.Fatalf("unexpected error: %v", err)
    69  		}
    70  		if cancel == nil {
    71  			t.Fatal("expected cancel function to be non-nil")
    72  		}
    73  		if creds != initialCreds {
    74  			t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
    75  		}
    76  		if len(receivedCreds) != 1 {
    77  			t.Fatalf("expected 1 received credential, got %d", len(receivedCreds))
    78  		}
    79  		if receivedCreds[0] != initialCreds {
    80  			t.Fatalf("expected received credential %v, got %v", initialCreds, receivedCreds[0])
    81  		}
    82  		if len(receivedErrors) != 0 {
    83  			t.Fatalf("expected no errors, got %d", len(receivedErrors))
    84  		}
    85  
    86  		// Send an update
    87  		newCreds := NewBasicCredentials("user2", "pass2")
    88  		provider.updates <- newCreds
    89  
    90  		// Wait for update to be processed
    91  		time.Sleep(100 * time.Millisecond)
    92  		mu.Lock()
    93  		if len(receivedCreds) != 2 {
    94  			t.Fatalf("expected 2 received credentials, got %d", len(receivedCreds))
    95  		}
    96  		if receivedCreds[1] != newCreds {
    97  			t.Fatalf("expected received credential %v, got %v", newCreds, receivedCreds[1])
    98  		}
    99  		mu.Unlock()
   100  
   101  		// Cancel subscription
   102  		if err := cancel(); err != nil {
   103  			t.Fatalf("unexpected error cancelling subscription: %v", err)
   104  		}
   105  	})
   106  
   107  	t.Run("subscription error", func(t *testing.T) {
   108  		provider := &mockStreamingProvider{
   109  			err: errors.New("subscription failed"),
   110  		}
   111  
   112  		var receivedCreds []Credentials
   113  		var receivedErrors []error
   114  
   115  		listener := NewReAuthCredentialsListener(
   116  			func(creds Credentials) error {
   117  				receivedCreds = append(receivedCreds, creds)
   118  				return nil
   119  			},
   120  			func(err error) {
   121  				receivedErrors = append(receivedErrors, err)
   122  			},
   123  		)
   124  
   125  		creds, cancel, err := provider.Subscribe(listener)
   126  		if err == nil {
   127  			t.Fatal("expected error, got nil")
   128  		}
   129  		if cancel != nil {
   130  			t.Fatal("expected cancel function to be nil")
   131  		}
   132  		if creds != nil {
   133  			t.Fatalf("expected nil credentials, got %v", creds)
   134  		}
   135  		if len(receivedCreds) != 0 {
   136  			t.Fatalf("expected no received credentials, got %d", len(receivedCreds))
   137  		}
   138  		if len(receivedErrors) != 0 {
   139  			t.Fatalf("expected no errors, got %d", len(receivedErrors))
   140  		}
   141  	})
   142  
   143  	t.Run("re-auth error", func(t *testing.T) {
   144  		initialCreds := NewBasicCredentials("user1", "pass1")
   145  		provider := newMockStreamingProvider(initialCreds)
   146  
   147  		reauthErr := errors.New("re-auth failed")
   148  		var receivedErrors []error
   149  
   150  		listener := NewReAuthCredentialsListener(
   151  			func(creds Credentials) error {
   152  				return reauthErr
   153  			},
   154  			func(err error) {
   155  				receivedErrors = append(receivedErrors, err)
   156  			},
   157  		)
   158  
   159  		creds, cancel, err := provider.Subscribe(listener)
   160  		if err != nil {
   161  			t.Fatalf("unexpected error: %v", err)
   162  		}
   163  		if cancel == nil {
   164  			t.Fatal("expected cancel function to be non-nil")
   165  		}
   166  		if creds != initialCreds {
   167  			t.Fatalf("expected credentials %v, got %v", initialCreds, creds)
   168  		}
   169  		if len(receivedErrors) != 1 {
   170  			t.Fatalf("expected 1 error, got %d", len(receivedErrors))
   171  		}
   172  		if receivedErrors[0] != reauthErr {
   173  			t.Fatalf("expected error %v, got %v", reauthErr, receivedErrors[0])
   174  		}
   175  
   176  		if err := cancel(); err != nil {
   177  			t.Fatalf("unexpected error cancelling subscription: %v", err)
   178  		}
   179  	})
   180  }
   181  
   182  func TestBasicCredentials(t *testing.T) {
   183  	tests := []struct {
   184  		name         string
   185  		username     string
   186  		password     string
   187  		expectedUser string
   188  		expectedPass string
   189  		expectedRaw  string
   190  	}{
   191  		{
   192  			name:         "basic auth",
   193  			username:     "user1",
   194  			password:     "pass1",
   195  			expectedUser: "user1",
   196  			expectedPass: "pass1",
   197  			expectedRaw:  "user1:pass1",
   198  		},
   199  		{
   200  			name:         "empty username",
   201  			username:     "",
   202  			password:     "pass1",
   203  			expectedUser: "",
   204  			expectedPass: "pass1",
   205  			expectedRaw:  ":pass1",
   206  		},
   207  		{
   208  			name:         "empty password",
   209  			username:     "user1",
   210  			password:     "",
   211  			expectedUser: "user1",
   212  			expectedPass: "",
   213  			expectedRaw:  "user1:",
   214  		},
   215  		{
   216  			name:         "both username and password empty",
   217  			username:     "",
   218  			password:     "",
   219  			expectedUser: "",
   220  			expectedPass: "",
   221  			expectedRaw:  ":",
   222  		},
   223  		{
   224  			name:         "special characters",
   225  			username:     "user:1",
   226  			password:     "pa:ss@!#",
   227  			expectedUser: "user:1",
   228  			expectedPass: "pa:ss@!#",
   229  			expectedRaw:  "user:1:pa:ss@!#",
   230  		},
   231  		{
   232  			name:         "unicode characters",
   233  			username:     "ユーザー",
   234  			password:     "密碼123",
   235  			expectedUser: "ユーザー",
   236  			expectedPass: "密碼123",
   237  			expectedRaw:  "ユーザー:密碼123",
   238  		},
   239  		{
   240  			name:         "long credentials",
   241  			username:     strings.Repeat("u", 1000),
   242  			password:     strings.Repeat("p", 1000),
   243  			expectedUser: strings.Repeat("u", 1000),
   244  			expectedPass: strings.Repeat("p", 1000),
   245  			expectedRaw:  strings.Repeat("u", 1000) + ":" + strings.Repeat("p", 1000),
   246  		},
   247  	}
   248  
   249  	for _, tt := range tests {
   250  		t.Run(tt.name, func(t *testing.T) {
   251  			creds := NewBasicCredentials(tt.username, tt.password)
   252  
   253  			user, pass := creds.BasicAuth()
   254  			if user != tt.expectedUser {
   255  				t.Errorf("BasicAuth() username = %q; want %q", user, tt.expectedUser)
   256  			}
   257  			if pass != tt.expectedPass {
   258  				t.Errorf("BasicAuth() password = %q; want %q", pass, tt.expectedPass)
   259  			}
   260  
   261  			raw := creds.RawCredentials()
   262  			if raw != tt.expectedRaw {
   263  				t.Errorf("RawCredentials() = %q; want %q", raw, tt.expectedRaw)
   264  			}
   265  		})
   266  	}
   267  }
   268  
   269  func TestReAuthCredentialsListener(t *testing.T) {
   270  	t.Run("successful re-auth", func(t *testing.T) {
   271  		var reAuthCalled bool
   272  		var onErrCalled bool
   273  		var receivedCreds Credentials
   274  
   275  		listener := NewReAuthCredentialsListener(
   276  			func(creds Credentials) error {
   277  				reAuthCalled = true
   278  				receivedCreds = creds
   279  				return nil
   280  			},
   281  			func(err error) {
   282  				onErrCalled = true
   283  			},
   284  		)
   285  
   286  		creds := NewBasicCredentials("user1", "pass1")
   287  		listener.OnNext(creds)
   288  
   289  		if !reAuthCalled {
   290  			t.Fatal("expected reAuth to be called")
   291  		}
   292  		if onErrCalled {
   293  			t.Fatal("expected onErr not to be called")
   294  		}
   295  		if receivedCreds != creds {
   296  			t.Fatalf("expected credentials %v, got %v", creds, receivedCreds)
   297  		}
   298  	})
   299  
   300  	t.Run("re-auth error", func(t *testing.T) {
   301  		var reAuthCalled bool
   302  		var onErrCalled bool
   303  		var receivedErr error
   304  		expectedErr := errors.New("re-auth failed")
   305  
   306  		listener := NewReAuthCredentialsListener(
   307  			func(creds Credentials) error {
   308  				reAuthCalled = true
   309  				return expectedErr
   310  			},
   311  			func(err error) {
   312  				onErrCalled = true
   313  				receivedErr = err
   314  			},
   315  		)
   316  
   317  		creds := NewBasicCredentials("user1", "pass1")
   318  		listener.OnNext(creds)
   319  
   320  		if !reAuthCalled {
   321  			t.Fatal("expected reAuth to be called")
   322  		}
   323  		if !onErrCalled {
   324  			t.Fatal("expected onErr to be called")
   325  		}
   326  		if receivedErr != expectedErr {
   327  			t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
   328  		}
   329  	})
   330  
   331  	t.Run("on error", func(t *testing.T) {
   332  		var onErrCalled bool
   333  		var receivedErr error
   334  		expectedErr := errors.New("provider error")
   335  
   336  		listener := NewReAuthCredentialsListener(
   337  			func(creds Credentials) error {
   338  				return nil
   339  			},
   340  			func(err error) {
   341  				onErrCalled = true
   342  				receivedErr = err
   343  			},
   344  		)
   345  
   346  		listener.OnError(expectedErr)
   347  
   348  		if !onErrCalled {
   349  			t.Fatal("expected onErr to be called")
   350  		}
   351  		if receivedErr != expectedErr {
   352  			t.Fatalf("expected error %v, got %v", expectedErr, receivedErr)
   353  		}
   354  	})
   355  
   356  	t.Run("nil callbacks", func(t *testing.T) {
   357  		listener := NewReAuthCredentialsListener(nil, nil)
   358  
   359  		// Should not panic
   360  		listener.OnNext(NewBasicCredentials("user1", "pass1"))
   361  		listener.OnError(errors.New("test error"))
   362  	})
   363  }
   364  

View as plain text