/* Copyright 2016 The Kubernetes Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package gcp import ( "fmt" "io/ioutil" "net/http" "os" "os/exec" "reflect" "strings" "sync" "testing" "time" "golang.org/x/oauth2" ) type fakeOutput struct { args []string output string } var ( wantCmd []string // Output for fakeExec, keyed by command execOutputs = map[string]fakeOutput{ "/default/no/args": { args: []string{}, output: `{ "access_token": "faketoken", "token_expiry": "2016-10-31T22:31:09.123000000Z" }`}, "/default/legacy/args": { args: []string{"arg1", "arg2", "arg3"}, output: `{ "access_token": "faketoken", "token_expiry": "2016-10-31T22:31:09.123000000Z" }`}, "/space in path/customkeys": { args: []string{"can", "haz", "auth"}, output: `{ "token": "faketoken", "token_expiry": { "datetime": "2016-10-31 22:31:09.123" } }`}, "missing/tokenkey/noargs": { args: []string{}, output: `{ "broken": "faketoken", "token_expiry": { "datetime": "2016-10-31 22:31:09.123000000Z" } }`}, "missing/expirykey/legacyargs": { args: []string{"split", "on", "whitespace"}, output: `{ "access_token": "faketoken", "expires": "2016-10-31T22:31:09.123000000Z" }`}, "invalid expiry/timestamp": { args: []string{"foo", "--bar", "--baz=abc,def"}, output: `{ "access_token": "faketoken", "token_expiry": "sometime soon, idk" }`}, "badjson": { args: []string{}, output: `{ "access_token": "faketoken", "token_expiry": "sometime soon, idk" ------ `}, } ) func fakeExec(command string, args ...string) *exec.Cmd { cs := []string{"-test.run=TestHelperProcess", "--", command} cs = append(cs, args...) cmd := exec.Command(os.Args[0], cs...) cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"} return cmd } func TestHelperProcess(t *testing.T) { if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { return } // Strip out the leading args used to exec into this function. gotCmd := os.Args[3] gotArgs := os.Args[4:] output, ok := execOutputs[gotCmd] if !ok { fmt.Fprintf(os.Stdout, "unexpected call cmd=%q args=%v\n", gotCmd, gotArgs) os.Exit(1) } else if !reflect.DeepEqual(output.args, gotArgs) { fmt.Fprintf(os.Stdout, "call cmd=%q got args %v, want: %v\n", gotCmd, gotArgs, output.args) os.Exit(1) } fmt.Fprintf(os.Stdout, output.output) os.Exit(0) } func Test_isCmdTokenSource(t *testing.T) { c1 := map[string]string{"cmd-path": "foo"} if v := isCmdTokenSource(c1); !v { t.Fatalf("cmd-path present in config (%+v), but got %v", c1, v) } c2 := map[string]string{"cmd-args": "foo bar"} if v := isCmdTokenSource(c2); v { t.Fatalf("cmd-path not present in config (%+v), but got %v", c2, v) } } func Test_tokenSource_cmd(t *testing.T) { if _, err := tokenSource(true, map[string]string{}); err == nil { t.Fatalf("expected error, cmd-args not present in config") } c := map[string]string{ "cmd-path": "foo", "cmd-args": "bar"} ts, err := tokenSource(true, c) if err != nil { t.Fatalf("failed to return cmd token source: %+v", err) } if ts == nil { t.Fatal("returned nil token source") } if _, ok := ts.(*commandTokenSource); !ok { t.Fatalf("returned token source type:(%T) expected:(*commandTokenSource)", ts) } } func Test_tokenSource_cmdCannotBeUsedWithScopes(t *testing.T) { c := map[string]string{ "cmd-path": "foo", "scopes": "A,B"} if _, err := tokenSource(true, c); err == nil { t.Fatal("expected error when scopes is used with cmd-path") } } func Test_tokenSource_applicationDefaultCredentials_fails(t *testing.T) { // try to use empty ADC file fakeTokenFile, err := ioutil.TempFile("", "adctoken") if err != nil { t.Fatalf("failed to create fake token file: +%v", err) } fakeTokenFile.Close() defer os.Remove(fakeTokenFile.Name()) os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", fakeTokenFile.Name()) defer os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS") if _, err := tokenSource(false, map[string]string{}); err == nil { t.Fatalf("expected error because specified ADC token file is not a JSON") } } func Test_tokenSource_applicationDefaultCredentials(t *testing.T) { fakeTokenFile, err := ioutil.TempFile("", "adctoken") if err != nil { t.Fatalf("failed to create fake token file: +%v", err) } fakeTokenFile.Close() defer os.Remove(fakeTokenFile.Name()) if err := ioutil.WriteFile(fakeTokenFile.Name(), []byte(`{"type":"service_account"}`), 0600); err != nil { t.Fatalf("failed to write to fake token file: %+v", err) } os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", fakeTokenFile.Name()) defer os.Unsetenv("GOOGLE_APPLICATION_CREDENTIALS") ts, err := tokenSource(false, map[string]string{}) if err != nil { t.Fatalf("failed to get a token source: %+v", err) } if ts == nil { t.Fatal("returned nil token source") } } func Test_parseScopes(t *testing.T) { cases := []struct { in map[string]string out []string }{ { map[string]string{}, []string{ "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/userinfo.email"}, }, { map[string]string{"scopes": ""}, []string{}, }, { map[string]string{"scopes": "A,B,C"}, []string{"A", "B", "C"}, }, } for _, c := range cases { got := parseScopes(c.in) if !reflect.DeepEqual(got, c.out) { t.Errorf("expected=%v, got=%v", c.out, got) } } } func errEquiv(got, want error) bool { if got == want { return true } if got != nil && want != nil { return strings.Contains(got.Error(), want.Error()) } return false } func TestCmdTokenSource(t *testing.T) { execCommand = fakeExec fakeExpiry := time.Date(2016, 10, 31, 22, 31, 9, 123000000, time.UTC) customFmt := "2006-01-02 15:04:05.999999999" tests := []struct { name string gcpConfig map[string]string tok *oauth2.Token newErr, tokenErr error }{ { "default", map[string]string{ "cmd-path": "/default/no/args", }, &oauth2.Token{ AccessToken: "faketoken", TokenType: "Bearer", Expiry: fakeExpiry, }, nil, nil, }, { "default legacy args", map[string]string{ "cmd-path": "/default/legacy/args arg1 arg2 arg3", }, &oauth2.Token{ AccessToken: "faketoken", TokenType: "Bearer", Expiry: fakeExpiry, }, nil, nil, }, { "custom keys", map[string]string{ "cmd-path": "/space in path/customkeys", "cmd-args": "can haz auth", "token-key": "{.token}", "expiry-key": "{.token_expiry.datetime}", "time-fmt": customFmt, }, &oauth2.Token{ AccessToken: "faketoken", TokenType: "Bearer", Expiry: fakeExpiry, }, nil, nil, }, { "missing cmd", map[string]string{ "cmd-path": "", }, nil, fmt.Errorf("missing access token cmd"), nil, }, { "missing token-key", map[string]string{ "cmd-path": "missing/tokenkey/noargs", "token-key": "{.token}", }, nil, nil, fmt.Errorf("error parsing token-key %q", "{.token}"), }, { "missing expiry-key", map[string]string{ "cmd-path": "missing/expirykey/legacyargs split on whitespace", "expiry-key": "{.expiry}", }, nil, nil, fmt.Errorf("error parsing expiry-key %q", "{.expiry}"), }, { "invalid expiry timestamp", map[string]string{ "cmd-path": "invalid expiry/timestamp", "cmd-args": "foo --bar --baz=abc,def", }, &oauth2.Token{ AccessToken: "faketoken", TokenType: "Bearer", Expiry: time.Time{}, }, nil, nil, }, { "bad JSON", map[string]string{ "cmd-path": "badjson", }, nil, nil, fmt.Errorf("invalid character '-' after object key:value pair"), }, } for _, tc := range tests { provider, err := newGCPAuthProvider("", tc.gcpConfig, nil /* persister */) if !errEquiv(err, tc.newErr) { t.Errorf("%q newGCPAuthProvider error: got %v, want %v", tc.name, err, tc.newErr) continue } if err != nil { continue } ts := provider.(*gcpAuthProvider).tokenSource.(*cachedTokenSource).source.(*commandTokenSource) wantCmd = append([]string{ts.cmd}, ts.args...) tok, err := ts.Token() if !errEquiv(err, tc.tokenErr) { t.Errorf("%q Token() error: got %v, want %v", tc.name, err, tc.tokenErr) } if !reflect.DeepEqual(tok, tc.tok) { t.Errorf("%q Token() got %v, want %v", tc.name, tok, tc.tok) } } } type fakePersister struct { lk sync.Mutex cache map[string]string } func (f *fakePersister) Persist(cache map[string]string) error { f.lk.Lock() defer f.lk.Unlock() f.cache = map[string]string{} for k, v := range cache { f.cache[k] = v } return nil } func (f *fakePersister) read() map[string]string { ret := map[string]string{} f.lk.Lock() defer f.lk.Unlock() for k, v := range f.cache { ret[k] = v } return ret } type fakeTokenSource struct { token *oauth2.Token err error } func (f *fakeTokenSource) Token() (*oauth2.Token, error) { return f.token, f.err } func TestCachedTokenSource(t *testing.T) { tok := &oauth2.Token{AccessToken: "fakeaccesstoken"} persister := &fakePersister{} source := &fakeTokenSource{ token: tok, err: nil, } cache := map[string]string{ "foo": "bar", "baz": "bazinga", } ts, err := newCachedTokenSource("fakeaccesstoken", "", persister, source, cache) if err != nil { t.Fatal(err) } var wg sync.WaitGroup wg.Add(10) for i := 0; i < 10; i++ { go func() { _, err := ts.Token() if err != nil { t.Errorf("unexpected error: %s", err) } wg.Done() }() } wg.Wait() cache["access-token"] = "fakeaccesstoken" cache["expiry"] = tok.Expiry.Format(time.RFC3339Nano) if got := persister.read(); !reflect.DeepEqual(got, cache) { t.Errorf("got cache %v, want %v", got, cache) } } type MockTransport struct { res *http.Response } func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { return t.res, nil } func Test_cmdTokenSource_roundTrip(t *testing.T) { accessToken := "fakeToken" fakeExpiry := time.Now().Add(time.Hour) fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano) fs := &fakeTokenSource{ token: &oauth2.Token{ AccessToken: accessToken, Expiry: fakeExpiry, }, } cmdCache := map[string]string{ "cmd-path": "/path/to/tokensource/cmd", "cmd-args": "--output=json", } cmdCacheUpdated := map[string]string{ "cmd-path": "/path/to/tokensource/cmd", "cmd-args": "--output=json", "access-token": accessToken, "expiry": fakeExpiryStr, } simpleCacheUpdated := map[string]string{ "access-token": accessToken, "expiry": fakeExpiryStr, } tests := []struct { name string res http.Response baseCache, expectedCache map[string]string }{ { "Unauthorized", http.Response{StatusCode: http.StatusUnauthorized}, make(map[string]string), make(map[string]string), }, { "Unauthorized, nonempty defaultCache", http.Response{StatusCode: http.StatusUnauthorized}, cmdCache, cmdCache, }, { "Authorized", http.Response{StatusCode: http.StatusOK}, make(map[string]string), simpleCacheUpdated, }, { "Authorized, nonempty defaultCache", http.Response{StatusCode: http.StatusOK}, cmdCache, cmdCacheUpdated, }, } persister := &fakePersister{} req := http.Request{Header: http.Header{}} for _, tc := range tests { cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache) if err != nil { t.Fatalf("unexpected error from newCachedTokenSource: %v", err) } authProvider := gcpAuthProvider{cts, persister} fakeTransport := MockTransport{&tc.res} transport := (authProvider.WrapTransport(&fakeTransport)) // call Token to persist/update cache if _, err := cts.Token(); err != nil { t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err) } transport.RoundTrip(&req) if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) { t.Errorf("got cache %v, want %v", got, tc.expectedCache) } } }