/* Copyright 2017 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package azure import ( "encoding/json" "errors" "fmt" "net/http" "strconv" "strings" "sync" "testing" "time" "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure" ) func TestAzureAuthProvider(t *testing.T) { t.Run("validate against invalid configurations", func(t *testing.T) { vectors := []struct { cfg map[string]string expectedError string }{ { cfg: map[string]string{ cfgClientID: "foo", cfgApiserverID: "foo", cfgTenantID: "foo", cfgConfigMode: "-1", }, expectedError: "config-mode:-1 is not a valid mode", }, { cfg: map[string]string{ cfgClientID: "foo", cfgApiserverID: "foo", cfgTenantID: "foo", cfgConfigMode: "2", }, expectedError: "config-mode:2 is not a valid mode", }, { cfg: map[string]string{ cfgClientID: "foo", cfgApiserverID: "foo", cfgTenantID: "foo", cfgConfigMode: "foo", }, expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax", }, } for _, v := range vectors { persister := &fakePersister{} _, err := newAzureAuthProvider("", v.cfg, persister) if !strings.Contains(err.Error(), v.expectedError) { t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err) } } }) t.Run("it should return non-nil provider in happy cases", func(t *testing.T) { vectors := []struct { cfg map[string]string expectedConfigMode configMode }{ { cfg: map[string]string{ cfgClientID: "foo", cfgApiserverID: "foo", cfgTenantID: "foo", }, expectedConfigMode: configModeDefault, }, { cfg: map[string]string{ cfgClientID: "foo", cfgApiserverID: "foo", cfgTenantID: "foo", cfgConfigMode: "0", }, expectedConfigMode: configModeDefault, }, { cfg: map[string]string{ cfgClientID: "foo", cfgApiserverID: "foo", cfgTenantID: "foo", cfgConfigMode: "1", }, expectedConfigMode: configModeOmitSPNPrefix, }, } for _, v := range vectors { persister := &fakePersister{} provider, err := newAzureAuthProvider("", v.cfg, persister) if err != nil { t.Errorf("newAzureAuthProvider should not fail with '%s'", err) } if provider == nil { t.Fatalf("newAzureAuthProvider should return non-nil provider") } azureProvider := provider.(*azureAuthProvider) if azureProvider == nil { t.Fatalf("newAzureAuthProvider should return an instance of type azureAuthProvider") } ts := azureProvider.tokenSource.(*azureTokenSource) if ts == nil { t.Fatalf("azureAuthProvider should be an instance of azureTokenSource") } if ts.configMode != v.expectedConfigMode { t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode) } } }) } func TestTokenSourceDeviceCode(t *testing.T) { var ( clientID = "clientID" tenantID = "tenantID" apiserverID = "apiserverID" configMode = configModeDefault azureEnv = azure.Environment{} ) t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) { if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil { t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err) } if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil { t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err) } _, err := newAzureTokenSourceDeviceCode(azureEnv, "", tenantID, apiserverID, configMode) actual := "client-id is empty" if err.Error() != actual { t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) } _, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, "", apiserverID, configMode) actual = "tenant-id is empty" if err.Error() != actual { t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) } _, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, "", configMode) actual = "apiserver-id is empty" if err.Error() != actual { t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) } }) } func TestAzureTokenSource(t *testing.T) { configModes := []configMode{configModeOmitSPNPrefix, configModeDefault} expectedConfigModes := []string{"1", "0"} for i, configMode := range configModes { t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) { const ( serverID = "fakeServerID" clientID = "fakeClientID" tenantID = "fakeTenantID" accessToken = "fakeToken" environment = "fakeEnvironment" refreshToken = "fakeToken" expiresIn = "foo" expiresOn = "foo" ) cfg := map[string]string{ cfgConfigMode: strconv.Itoa(int(configMode)), cfgApiserverID: serverID, cfgClientID: clientID, cfgTenantID: tenantID, cfgEnvironment: environment, cfgAccessToken: accessToken, cfgRefreshToken: refreshToken, cfgExpiresIn: expiresIn, cfgExpiresOn: expiresOn, } fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))} persiter := &fakePersister{cache: make(map[string]string)} tokenCache := newAzureTokenCache() tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) azTokenSource := tokenSource.(*azureTokenSource) token, err := azTokenSource.retrieveTokenFromCfg() if err != nil { t.Errorf("failed to retrieve the token form cfg: %s", err) } if token.apiserverID != serverID { t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID) } if token.clientID != clientID { t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID) } if token.tenantID != tenantID { t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID) } expectedAudience := serverID if configMode == configModeDefault { expectedAudience = fmt.Sprintf("spn:%s", serverID) } if token.token.Resource != expectedAudience { t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource) } }) t.Run("validate token against cache", func(t *testing.T) { fakeAccessToken := "fake token 1" fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))} cfg := make(map[string]string) persiter := &fakePersister{cache: make(map[string]string)} tokenCache := newAzureTokenCache() tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) token, err := tokenSource.Token() if err != nil { t.Errorf("failed to retrieve the token form cache: %v", err) } wantCacheLen := 1 if len(tokenCache.cache) != wantCacheLen { t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen) } if token != tokenCache.cache[azureTokenKey] { t.Error("Token() returned token != cached token") } wantCfg := token2Cfg(token) wantCfg[cfgConfigMode] = expectedConfigModes[i] persistedCfg := persiter.Cache() wantCfgLen := len(wantCfg) persistedCfgLen := len(persistedCfg) if wantCfgLen != persistedCfgLen { t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen) } for k, v := range persistedCfg { if strings.Compare(v, wantCfg[k]) != 0 { t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k]) } } fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second)) token, err = tokenSource.Token() if err != nil { t.Errorf("failed to retrieve the cached token: %v", err) } if token.token.AccessToken != fakeAccessToken { t.Errorf("Token() didn't return the cached token") } }) } } func TestAzureTokenSourceScenarios(t *testing.T) { expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second)) extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second)) fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second)) wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second)) tests := []struct { name string sourceToken *azureToken refreshToken *azureToken cachedToken *azureToken configToken *azureToken expectToken *azureToken tokenErr error refreshErr error expectErr string tokenCalls uint refreshCalls uint persistCalls uint }{ { name: "new config", sourceToken: fakeToken, expectToken: fakeToken, tokenCalls: 1, persistCalls: 1, }, { name: "load token from cache", sourceToken: wrongToken, cachedToken: fakeToken, configToken: wrongToken, expectToken: fakeToken, }, { name: "load token from config", sourceToken: wrongToken, configToken: fakeToken, expectToken: fakeToken, }, { name: "cached token timeout, extend success, config token should never load", cachedToken: expiredToken, refreshToken: extendedToken, configToken: wrongToken, expectToken: extendedToken, refreshCalls: 1, persistCalls: 1, }, { name: "config token timeout, extend failure, acquire new token", configToken: expiredToken, refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"}, sourceToken: fakeToken, expectToken: fakeToken, refreshCalls: 1, tokenCalls: 1, persistCalls: 1, }, { name: "extend failure with fmt.Errorf nested tokenRefreshError", configToken: expiredToken, refreshErr: fmt.Errorf("refreshing token: %w", fakeTokenRefreshError{message: "nested FakeError happened when refreshing"}), sourceToken: fakeToken, expectToken: fakeToken, refreshCalls: 1, tokenCalls: 1, persistCalls: 1, }, { name: "unexpected error when extend", configToken: expiredToken, refreshErr: errors.New("unexpected refresh error"), sourceToken: fakeToken, expectErr: "unexpected refresh error", refreshCalls: 1, }, { name: "token error", tokenErr: errors.New("tokenerr"), expectErr: "tokenerr", tokenCalls: 1, }, { name: "Token() got expired token", sourceToken: expiredToken, expectErr: "newly acquired token is expired", tokenCalls: 1, }, { name: "Token() got nil but no error", sourceToken: nil, expectErr: "unable to acquire token", tokenCalls: 1, }, } for _, tc := range tests { configModes := []configMode{configModeOmitSPNPrefix, configModeDefault} for _, configMode := range configModes { t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) { persister := newFakePersister() cfg := map[string]string{ cfgConfigMode: strconv.Itoa(int(configMode)), } if tc.configToken != nil { cfg = token2Cfg(tc.configToken) } tokenCache := newAzureTokenCache() if tc.cachedToken != nil { tokenCache.setToken(azureTokenKey, tc.cachedToken) } fakeSource := fakeTokenSource{ token: tc.sourceToken, tokenErr: tc.tokenErr, refreshToken: tc.refreshToken, refreshErr: tc.refreshErr, } tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister) token, err := tokenSource.Token() if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID { t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID) } if fakeSource.tokenCalls != tc.tokenCalls { t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls) } if fakeSource.refreshCalls != tc.refreshCalls { t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls) } if persister.calls != tc.persistCalls { t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls) } if tc.expectErr != "" { if !strings.Contains(err.Error(), tc.expectErr) { t.Errorf("expecting error %v, got %v", tc.expectErr, err) } if token != nil { t.Errorf("token should be nil in err situation, got %v", token) } } else { if err != nil { t.Fatalf("error should be nil, got %v", err) } if token.token.AccessToken != tc.expectToken.token.AccessToken { t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken) } } }) } } } type fakePersister struct { lock sync.Mutex cache map[string]string calls uint } func newFakePersister() fakePersister { return fakePersister{cache: make(map[string]string), calls: 0} } func (p *fakePersister) Persist(cache map[string]string) error { p.lock.Lock() defer p.lock.Unlock() p.calls++ p.cache = map[string]string{} for k, v := range cache { p.cache[k] = v } return nil } func (p *fakePersister) Cache() map[string]string { ret := map[string]string{} p.lock.Lock() defer p.lock.Unlock() for k, v := range p.cache { ret[k] = v } return ret } // a simple token source simply always returns the token property type fakeTokenSource struct { token *azureToken tokenCalls uint tokenErr error refreshToken *azureToken refreshCalls uint refreshErr error } func (ts *fakeTokenSource) Token() (*azureToken, error) { ts.tokenCalls++ return ts.token, ts.tokenErr } func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) { ts.refreshCalls++ return ts.refreshToken, ts.refreshErr } func token2Cfg(token *azureToken) map[string]string { cfg := make(map[string]string) cfg[cfgAccessToken] = token.token.AccessToken cfg[cfgRefreshToken] = token.token.RefreshToken cfg[cfgEnvironment] = token.environment cfg[cfgClientID] = token.clientID cfg[cfgTenantID] = token.tenantID cfg[cfgApiserverID] = token.apiserverID cfg[cfgExpiresIn] = string(token.token.ExpiresIn) cfg[cfgExpiresOn] = string(token.token.ExpiresOn) return cfg } func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken { return &azureToken{ token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)), environment: "testenv", clientID: "fake", tenantID: "fake", apiserverID: "fake", } } func newFakeADALToken(accessToken string, expiresOn string) adal.Token { return adal.Token{ AccessToken: accessToken, RefreshToken: "fake", ExpiresIn: "3600", ExpiresOn: json.Number(expiresOn), NotBefore: json.Number(expiresOn), Resource: "fake", Type: "fake", } } // copied from go-autorest/adal type fakeTokenRefreshError struct { message string resp *http.Response } // Error implements the error interface which is part of the TokenRefreshError interface. func (tre fakeTokenRefreshError) Error() string { return tre.message } // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. func (tre fakeTokenRefreshError) Response() *http.Response { return tre.resp }