/* * * Copyright 2020 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package sts import ( "bytes" "context" "crypto/x509" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "net/http/httputil" "strings" "testing" "time" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" ) const ( requestedTokenType = "urn:ietf:params:oauth:token-type:access-token" actorTokenPath = "/var/run/secrets/token.jwt" actorTokenType = "urn:ietf:params:oauth:token-type:refresh_token" actorTokenContents = "actorToken.jwt.contents" accessTokenContents = "access_token" subjectTokenPath = "/var/run/secrets/token.jwt" subjectTokenType = "urn:ietf:params:oauth:token-type:id_token" subjectTokenContents = "subjectToken.jwt.contents" serviceURI = "http://localhost" exampleResource = "https://backend.example.com/api" exampleAudience = "example-backend-service" testScope = "https://www.googleapis.com/auth/monitoring" defaultTestTimeout = 1 * time.Second defaultTestShortTimeout = 10 * time.Millisecond ) var ( goodOptions = Options{ TokenExchangeServiceURI: serviceURI, Audience: exampleAudience, RequestedTokenType: requestedTokenType, SubjectTokenPath: subjectTokenPath, SubjectTokenType: subjectTokenType, } goodRequestParams = &requestParameters{ GrantType: tokenExchangeGrantType, Audience: exampleAudience, Scope: defaultCloudPlatformScope, RequestedTokenType: requestedTokenType, SubjectToken: subjectTokenContents, SubjectTokenType: subjectTokenType, } goodMetadata = map[string]string{ "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents), } ) type s struct { grpctest.Tester } func Test(t *testing.T) { grpctest.RunSubTests(t, s{}) } // A struct that implements AuthInfo interface and added to the context passed // to GetRequestMetadata from tests. type testAuthInfo struct { credentials.CommonAuthInfo } func (ta testAuthInfo) AuthType() string { return "testAuthInfo" } func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context { auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}} ri := credentials.RequestInfo{ Method: "testInfo", AuthInfo: auth, } return icredentials.NewRequestInfoContext(ctx, ri) } // errReader implements the io.Reader interface and returns an error from the // Read method. type errReader struct{} func (r errReader) Read(b []byte) (n int, err error) { return 0, errors.New("read error") } // We need a function to construct the response instead of simply declaring it // as a variable since the the response body will be consumed by the // credentials, and therefore we will need a new one everytime. func makeGoodResponse() *http.Response { respJSON, _ := json.Marshal(responseParameters{ AccessToken: accessTokenContents, IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", TokenType: "Bearer", ExpiresIn: 3600, }) respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) return &http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Body: respBody, } } // Overrides the http.Client with a fakeClient which sends a good response. func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) { fc := &testutils.FakeHTTPClient{ ReqChan: testutils.NewChannel(), RespChan: testutils.NewChannel(), } fc.RespChan.Send(makeGoodResponse()) origMakeHTTPDoer := makeHTTPDoer makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } return fc, func() { makeHTTPDoer = origMakeHTTPDoer } } // Overrides the http.Client with the provided fakeClient. func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() { origMakeHTTPDoer := makeHTTPDoer makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc } return func() { makeHTTPDoer = origMakeHTTPDoer } } // Overrides the subject token read to return a const which we can compare in // our tests. func overrideSubjectTokenGood() func() { origReadSubjectTokenFrom := readSubjectTokenFrom readSubjectTokenFrom = func(path string) ([]byte, error) { return []byte(subjectTokenContents), nil } return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } } // Overrides the subject token read to always return an error. func overrideSubjectTokenError() func() { origReadSubjectTokenFrom := readSubjectTokenFrom readSubjectTokenFrom = func(path string) ([]byte, error) { return nil, errors.New("error reading subject token") } return func() { readSubjectTokenFrom = origReadSubjectTokenFrom } } // Overrides the actor token read to return a const which we can compare in // our tests. func overrideActorTokenGood() func() { origReadActorTokenFrom := readActorTokenFrom readActorTokenFrom = func(path string) ([]byte, error) { return []byte(actorTokenContents), nil } return func() { readActorTokenFrom = origReadActorTokenFrom } } // Overrides the actor token read to always return an error. func overrideActorTokenError() func() { origReadActorTokenFrom := readActorTokenFrom readActorTokenFrom = func(path string) ([]byte, error) { return nil, errors.New("error reading actor token") } return func() { readActorTokenFrom = origReadActorTokenFrom } } // compareRequest compares the http.Request received in the test with the // expected requestParameters specified in wantReqParams. func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error { jsonBody, err := json.Marshal(wantReqParams) if err != nil { return err } wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody)) if err != nil { return fmt.Errorf("failed to create http request: %v", err) } wantReq.Header.Set("Content-Type", "application/json") wantR, err := httputil.DumpRequestOut(wantReq, true) if err != nil { return err } gotR, err := httputil.DumpRequestOut(gotRequest, true) if err != nil { return err } if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" { return fmt.Errorf("sts request diff (-want +got):\n%s", diff) } return nil } // receiveAndCompareRequest waits for a request to be sent out by the // credentials implementation using the fakeHTTPClient and compares it to an // expected goodRequest. This is expected to be called in a separate goroutine // by the tests. So, any errors encountered are pushed to an error channel // which is monitored by the test. func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() val, err := ReqChan.Receive(ctx) if err != nil { errCh <- err return } req := val.(*http.Request) if err := compareRequest(req, goodRequestParams); err != nil { errCh <- err return } errCh <- nil } // TestGetRequestMetadataSuccess verifies the successful case of sending an // token exchange request and processing the response. func (s) TestGetRequestMetadataSuccess(t *testing.T) { defer overrideSubjectTokenGood()() fc, cancel := overrideHTTPClientGood() defer cancel() creds, err := NewCredentials(goodOptions) if err != nil { t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) } errCh := make(chan error, 1) go receiveAndCompareRequest(fc.ReqChan, errCh) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "") if err != nil { t.Fatalf("creds.GetRequestMetadata() = %v", err) } if !cmp.Equal(gotMetadata, goodMetadata) { t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) } if err := <-errCh; err != nil { t.Fatal(err) } // Make another call to get request metadata and this should return contents // from the cache. This will fail if the credentials tries to send a fresh // request here since we have not configured our fakeClient to return any // response on retries. gotMetadata, err = creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "") if err != nil { t.Fatalf("creds.GetRequestMetadata() = %v", err) } if !cmp.Equal(gotMetadata, goodMetadata) { t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) } } // TestGetRequestMetadataBadSecurityLevel verifies the case where the // securityLevel specified in the context passed to GetRequestMetadata is not // sufficient. func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) { defer overrideSubjectTokenGood()() creds, err := NewCredentials(goodOptions) if err != nil { t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.IntegrityOnly), "") if err == nil { t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata) } } // TestGetRequestMetadataCacheExpiry verifies the case where the cached access // token has expired, and the credentials implementation will have to send a // fresh token exchange request. func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) { const expiresInSecs = 1 defer overrideSubjectTokenGood()() fc := &testutils.FakeHTTPClient{ ReqChan: testutils.NewChannel(), RespChan: testutils.NewChannel(), } defer overrideHTTPClient(fc)() creds, err := NewCredentials(goodOptions) if err != nil { t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) } // The fakeClient is configured to return an access_token with a one second // expiry. So, in the second iteration, the credentials will find the cache // entry, but that would have expired, and therefore we expect it to send // out a fresh request. for i := 0; i < 2; i++ { errCh := make(chan error, 1) go receiveAndCompareRequest(fc.ReqChan, errCh) respJSON, _ := json.Marshal(responseParameters{ AccessToken: accessTokenContents, IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", TokenType: "Bearer", ExpiresIn: expiresInSecs, }) respBody := ioutil.NopCloser(bytes.NewReader(respJSON)) resp := &http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Body: respBody, } fc.RespChan.Send(resp) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "") if err != nil { t.Fatalf("creds.GetRequestMetadata() = %v", err) } if !cmp.Equal(gotMetadata, goodMetadata) { t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata) } if err := <-errCh; err != nil { t.Fatal(err) } time.Sleep(expiresInSecs * time.Second) } } // TestGetRequestMetadataBadResponses verifies the scenario where the token // exchange server returns bad responses. func (s) TestGetRequestMetadataBadResponses(t *testing.T) { tests := []struct { name string response *http.Response }{ { name: "bad JSON", response: &http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Body: ioutil.NopCloser(strings.NewReader("not JSON")), }, }, { name: "no access token", response: &http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Body: ioutil.NopCloser(strings.NewReader("{}")), }, }, } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for _, test := range tests { t.Run(test.name, func(t *testing.T) { defer overrideSubjectTokenGood()() fc := &testutils.FakeHTTPClient{ ReqChan: testutils.NewChannel(), RespChan: testutils.NewChannel(), } defer overrideHTTPClient(fc)() creds, err := NewCredentials(goodOptions) if err != nil { t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) } errCh := make(chan error, 1) go receiveAndCompareRequest(fc.ReqChan, errCh) fc.RespChan.Send(test.response) if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil { t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") } if err := <-errCh; err != nil { t.Fatal(err) } }) } } // TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the // attempt to read the subjectToken fails. func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) { defer overrideSubjectTokenError()() fc, cancel := overrideHTTPClientGood() defer cancel() creds, err := NewCredentials(goodOptions) if err != nil { t.Fatalf("NewCredentials(%v) = %v", goodOptions, err) } errCh := make(chan error, 1) go func() { ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) defer cancel() if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded { errCh <- err return } errCh <- nil }() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil { t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail") } if err := <-errCh; err != nil { t.Fatal(err) } } func (s) TestNewCredentials(t *testing.T) { tests := []struct { name string opts Options errSystemRoots bool wantErr bool }{ { name: "invalid options - empty subjectTokenPath", opts: Options{ TokenExchangeServiceURI: serviceURI, }, wantErr: true, }, { name: "invalid system root certs", opts: goodOptions, errSystemRoots: true, wantErr: true, }, { name: "good case", opts: goodOptions, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.errSystemRoots { oldSystemRoots := loadSystemCertPool loadSystemCertPool = func() (*x509.CertPool, error) { return nil, errors.New("failed to load system cert pool") } defer func() { loadSystemCertPool = oldSystemRoots }() } creds, err := NewCredentials(test.opts) if (err != nil) != test.wantErr { t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr) } if err == nil { if !creds.RequireTransportSecurity() { t.Errorf("creds.RequireTransportSecurity() returned false") } } }) } } func (s) TestValidateOptions(t *testing.T) { tests := []struct { name string opts Options wantErrPrefix string }{ { name: "empty token exchange service URI", opts: Options{}, wantErrPrefix: "empty token_exchange_service_uri in options", }, { name: "invalid URI", opts: Options{ TokenExchangeServiceURI: "\tI'm a bad URI\n", }, wantErrPrefix: "invalid control character in URL", }, { name: "unsupported scheme", opts: Options{ TokenExchangeServiceURI: "unix:///path/to/socket", }, wantErrPrefix: "scheme is not supported", }, { name: "empty subjectTokenPath", opts: Options{ TokenExchangeServiceURI: serviceURI, }, wantErrPrefix: "required field SubjectTokenPath is not specified", }, { name: "empty subjectTokenType", opts: Options{ TokenExchangeServiceURI: serviceURI, SubjectTokenPath: subjectTokenPath, }, wantErrPrefix: "required field SubjectTokenType is not specified", }, { name: "good options", opts: goodOptions, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { err := validateOptions(test.opts) if (err != nil) != (test.wantErrPrefix != "") { t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) } if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) { t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix) } }) } } func (s) TestConstructRequest(t *testing.T) { tests := []struct { name string opts Options subjectTokenReadErr bool actorTokenReadErr bool wantReqParams *requestParameters wantErr bool }{ { name: "subject token read failure", subjectTokenReadErr: true, opts: goodOptions, wantErr: true, }, { name: "actor token read failure", actorTokenReadErr: true, opts: Options{ TokenExchangeServiceURI: serviceURI, Audience: exampleAudience, RequestedTokenType: requestedTokenType, SubjectTokenPath: subjectTokenPath, SubjectTokenType: subjectTokenType, ActorTokenPath: actorTokenPath, ActorTokenType: actorTokenType, }, wantErr: true, }, { name: "default cloud platform scope", opts: goodOptions, wantReqParams: goodRequestParams, }, { name: "all good", opts: Options{ TokenExchangeServiceURI: serviceURI, Resource: exampleResource, Audience: exampleAudience, Scope: testScope, RequestedTokenType: requestedTokenType, SubjectTokenPath: subjectTokenPath, SubjectTokenType: subjectTokenType, ActorTokenPath: actorTokenPath, ActorTokenType: actorTokenType, }, wantReqParams: &requestParameters{ GrantType: tokenExchangeGrantType, Resource: exampleResource, Audience: exampleAudience, Scope: testScope, RequestedTokenType: requestedTokenType, SubjectToken: subjectTokenContents, SubjectTokenType: subjectTokenType, ActorToken: actorTokenContents, ActorTokenType: actorTokenType, }, }, } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.subjectTokenReadErr { defer overrideSubjectTokenError()() } else { defer overrideSubjectTokenGood()() } if test.actorTokenReadErr { defer overrideActorTokenError()() } else { defer overrideActorTokenGood()() } gotRequest, err := constructRequest(ctx, test.opts) if (err != nil) != test.wantErr { t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr) } if test.wantErr { return } if err := compareRequest(gotRequest, test.wantReqParams); err != nil { t.Fatal(err) } }) } } func (s) TestSendRequest(t *testing.T) { defer overrideSubjectTokenGood()() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() req, err := constructRequest(ctx, goodOptions) if err != nil { t.Fatal(err) } tests := []struct { name string resp *http.Response respErr error wantErr bool }{ { name: "client error", respErr: errors.New("http.Client.Do failed"), wantErr: true, }, { name: "bad response body", resp: &http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Body: ioutil.NopCloser(errReader{}), }, wantErr: true, }, { name: "nonOK status code", resp: &http.Response{ Status: "400 BadRequest", StatusCode: http.StatusBadRequest, Body: ioutil.NopCloser(strings.NewReader("")), }, wantErr: true, }, { name: "good case", resp: makeGoodResponse(), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { client := &testutils.FakeHTTPClient{ ReqChan: testutils.NewChannel(), RespChan: testutils.NewChannel(), Err: test.respErr, } client.RespChan.Send(test.resp) _, err := sendRequest(client, req) if (err != nil) != test.wantErr { t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr) } }) } } func (s) TestTokenInfoFromResponse(t *testing.T) { noAccessToken, _ := json.Marshal(responseParameters{ IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token", TokenType: "Bearer", ExpiresIn: 3600, }) goodResponse, _ := json.Marshal(responseParameters{ IssuedTokenType: requestedTokenType, AccessToken: accessTokenContents, TokenType: "Bearer", ExpiresIn: 3600, }) tests := []struct { name string respBody []byte wantTokenInfo *tokenInfo wantErr bool }{ { name: "bad JSON", respBody: []byte("not JSON"), wantErr: true, }, { name: "empty response", respBody: []byte(""), wantErr: true, }, { name: "non-empty response with no access token", respBody: noAccessToken, wantErr: true, }, { name: "good response", respBody: goodResponse, wantTokenInfo: &tokenInfo{ tokenType: "Bearer", token: accessTokenContents, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { gotTokenInfo, err := tokenInfoFromResponse(test.respBody) if (err != nil) != test.wantErr { t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr) } if test.wantErr { return } // Can't do a cmp.Equal on the whole struct since the expiryField // is populated based on time.Now(). if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token { t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo) } }) } }