// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package runtime import ( "context" "errors" "net/http" "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) const ( tokenValue = "***" accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}` accessTokenRespShortLived = `{"access_token": "` + tokenValue + `", "expires_in": 0}` scope = "scope" ) type mockCredential struct { getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) } func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) } return azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } func (mc mockCredential) Do(req *policy.Request) (*http.Response, error) { return nil, nil } func defaultTestPipeline(srv policy.Transporter, scope string) Pipeline { retryOpts := policy.RetryOptions{ MaxRetryDelay: 500 * time.Millisecond, RetryDelay: time.Millisecond, } b := NewBearerTokenPolicy(mockCredential{}, []string{scope}, nil) return NewPipeline( "testmodule", "v0.1.0", PipelineOptions{PerRetry: []policy.Policy{b}}, &policy.ClientOptions{Retry: retryOpts, Transport: srv}, ) } func TestBearerPolicy_SuccessGetToken(t *testing.T) { srv, close := mock.NewTLSServer() defer close() srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) pipeline := defaultTestPipeline(srv, scope) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) } resp, err := pipeline.Do(req) if err != nil { t.Fatalf("Expected nil error but received one") } const expectedToken = shared.BearerTokenPrefix + tokenValue if token := resp.Request.Header.Get(shared.HeaderAuthorization); token != expectedToken { t.Fatalf("expected token '%s', got '%s'", expectedToken, token) } } func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { srv, close := mock.NewTLSServer() defer close() expectedErr := errors.New("oops") failCredential := mockCredential{} failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { return azcore.AccessToken{}, expectedErr } b := NewBearerTokenPolicy(failCredential, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{ Transport: srv, Retry: policy.RetryOptions{ RetryDelay: 10 * time.Millisecond, }, PerRetryPolicies: []policy.Policy{b}, }) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) } resp, err := pipeline.Do(req) if err != expectedErr { t.Fatalf("unexpected error: %v", err) } if resp != nil { t.Fatal("expected nil response") } } func TestBearerTokenPolicy_TokenExpired(t *testing.T) { srv, close := mock.NewTLSServer() defer close() srv.AppendResponse(mock.WithBody([]byte(accessTokenRespShortLived))) srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) pipeline := defaultTestPipeline(srv, scope) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) } _, err = pipeline.Do(req) if err != nil { t.Fatalf("unexpected error %v", err) } _, err = pipeline.Do(req) if err != nil { t.Fatalf("unexpected error %v", err) } } func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { srv, close := mock.NewTLSServer() defer close() srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) retryOpts := policy.RetryOptions{ // use a negative try timeout to trigger a deadline exceeded error causing GetToken() to fail TryTimeout: -1 * time.Nanosecond, MaxRetryDelay: 500 * time.Millisecond, RetryDelay: 50 * time.Millisecond, MaxRetries: 3, } b := NewBearerTokenPolicy(mockCredential{}, nil, nil) pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []policy.Policy{b}}) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) } resp, err := pipeline.Do(req) if err == nil { t.Fatal("unexpected nil error") } if resp != nil { t.Fatal("expected nil response") } }