// Copyright 2015 The Prometheus Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // +build go1.8 package config import ( "context" "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "reflect" "strconv" "strings" "sync" "sync/atomic" "testing" "time" yaml "gopkg.in/yaml.v2" ) const ( TLSCAChainPath = "testdata/tls-ca-chain.pem" ServerCertificatePath = "testdata/server.crt" ServerKeyPath = "testdata/server.key" ClientCertificatePath = "testdata/client.crt" ClientKeyNoPassPath = "testdata/client-no-pass.key" InvalidCA = "testdata/client-no-pass.key" WrongClientCertPath = "testdata/self-signed-client.crt" WrongClientKeyPath = "testdata/self-signed-client.key" EmptyFile = "testdata/empty" MissingCA = "missing/ca.crt" MissingCert = "missing/cert.crt" MissingKey = "missing/secret.key" ExpectedMessage = "I'm here to serve you!!!" ExpectedError = "expected error" AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo" AuthorizationCredentialsFile = "testdata/bearer.token" AuthorizationType = "APIKEY" BearerToken = AuthorizationCredentials BearerTokenFile = AuthorizationCredentialsFile MissingBearerTokenFile = "missing/bearer.token" ExpectedBearer = "Bearer " + BearerToken ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken ExpectedUsername = "arthurdent" ExpectedPassword = "42" ) var invalidHTTPClientConfigs = []struct { httpClientConfigFile string errMsg string }{ { httpClientConfigFile: "testdata/http.conf.bearer-token-and-file-set.bad.yml", errMsg: "at most one of bearer_token & bearer_token_file must be configured", }, { httpClientConfigFile: "testdata/http.conf.empty.bad.yml", errMsg: "at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured", }, { httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml", errMsg: "at most one of basic_auth password & password_file must be configured", }, { httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml", errMsg: "authorization is not compatible with bearer_token & bearer_token_file", }, { httpClientConfigFile: "testdata/http.conf.auth-creds-and-file-set.too-much.bad.yaml", errMsg: "at most one of authorization credentials & credentials_file must be configured", }, { httpClientConfigFile: "testdata/http.conf.basic-auth-and-auth-creds.too-much.bad.yaml", errMsg: "at most one of basic_auth, oauth2 & authorization must be configured", }, { httpClientConfigFile: "testdata/http.conf.basic-auth-and-oauth2.too-much.bad.yaml", errMsg: "at most one of basic_auth, oauth2 & authorization must be configured", }, { httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml", errMsg: `authorization type cannot be set to "basic", use "basic_auth" instead`, }, { httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml", errMsg: "at most one of oauth2 client_secret & client_secret_file must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml", errMsg: "oauth2 client_id must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-client-secret.bad.yaml", errMsg: "either oauth2 client_secret or client_secret_file must be configured", }, { httpClientConfigFile: "testdata/http.conf.oauth2-no-token-url.bad.yaml", errMsg: "oauth2 token_url must be configured", }, } func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) { testServer := httptest.NewUnstartedServer(http.HandlerFunc(handler)) tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath) if err != nil { return nil, fmt.Errorf("Can't read %s", TLSCAChainPath) } serverCertificate, err := tls.LoadX509KeyPair(ServerCertificatePath, ServerKeyPath) if err != nil { return nil, fmt.Errorf("Can't load X509 key pair %s - %s", ServerCertificatePath, ServerKeyPath) } rootCAs := x509.NewCertPool() rootCAs.AppendCertsFromPEM(tlsCAChain) testServer.TLS = &tls.Config{ Certificates: make([]tls.Certificate, 1), RootCAs: rootCAs, ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: rootCAs} testServer.TLS.Certificates[0] = serverCertificate testServer.StartTLS() return testServer, nil } func TestNewClientFromConfig(t *testing.T) { var newClientValidConfig = []struct { clientConfig HTTPClientConfig handler func(w http.ResponseWriter, r *http.Request) }{ { clientConfig: HTTPClientConfig{ TLSConfig: TLSConfig{ CAFile: "", CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: true}, }, handler: func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, ExpectedMessage) }, }, { clientConfig: HTTPClientConfig{ TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, ExpectedMessage) }, }, { clientConfig: HTTPClientConfig{ BearerToken: BearerToken, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedBearer { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ BearerTokenFile: BearerTokenFile, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedBearer { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ Authorization: &Authorization{Credentials: BearerToken}, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedBearer { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ Authorization: &Authorization{CredentialsFile: AuthorizationCredentialsFile, Type: AuthorizationType}, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedAuthenticationCredentials { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedAuthenticationCredentials, bearer) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ Authorization: &Authorization{ Credentials: AuthorizationCredentials, Type: AuthorizationType, }, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedAuthenticationCredentials { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedAuthenticationCredentials, bearer) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ Authorization: &Authorization{ CredentialsFile: BearerTokenFile, }, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedBearer { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ BasicAuth: &BasicAuth{ Username: ExpectedUsername, Password: ExpectedPassword, }, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { username, password, ok := r.BasicAuth() if !ok { fmt.Fprintf(w, "The Authorization header wasn't set") } else if ExpectedUsername != username { fmt.Fprintf(w, "The expected username (%s) differs from the obtained username (%s).", ExpectedUsername, username) } else if ExpectedPassword != password { fmt.Fprintf(w, "The expected password (%s) differs from the obtained password (%s).", ExpectedPassword, password) } else { fmt.Fprint(w, ExpectedMessage) } }, }, { clientConfig: HTTPClientConfig{ FollowRedirects: true, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/redirected": fmt.Fprintf(w, ExpectedMessage) default: w.Header().Set("Location", "/redirected") w.WriteHeader(http.StatusFound) fmt.Fprintf(w, "It should follow the redirect.") } }, }, { clientConfig: HTTPClientConfig{ FollowRedirects: false, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, }, handler: func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/redirected": fmt.Fprint(w, "The redirection was followed.") default: w.Header().Set("Location", "/redirected") w.WriteHeader(http.StatusFound) fmt.Fprintf(w, ExpectedMessage) } }, }, } for _, validConfig := range newClientValidConfig { testServer, err := newTestServer(validConfig.handler) if err != nil { t.Fatal(err.Error()) } defer testServer.Close() err = validConfig.clientConfig.Validate() if err != nil { t.Fatal(err.Error()) } client, err := NewClientFromConfig(validConfig.clientConfig, "test") if err != nil { t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) continue } response, err := client.Get(testServer.URL) if err != nil { t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) continue } message, err := ioutil.ReadAll(response.Body) response.Body.Close() if err != nil { t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig) continue } trimMessage := strings.TrimSpace(string(message)) if ExpectedMessage != trimMessage { t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v", ExpectedMessage, trimMessage, validConfig.clientConfig) } } } func TestNewClientFromInvalidConfig(t *testing.T) { var newClientInvalidConfig = []struct { clientConfig HTTPClientConfig errorMsg string }{ { clientConfig: HTTPClientConfig{ TLSConfig: TLSConfig{ CAFile: MissingCA, InsecureSkipVerify: true}, }, errorMsg: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA), }, { clientConfig: HTTPClientConfig{ TLSConfig: TLSConfig{ CAFile: InvalidCA, InsecureSkipVerify: true}, }, errorMsg: fmt.Sprintf("unable to use specified CA cert %s", InvalidCA), }, } for _, invalidConfig := range newClientInvalidConfig { client, err := NewClientFromConfig(invalidConfig.clientConfig, "test") if client != nil { t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig) } if err == nil { t.Errorf("No error was returned using this config: %+v", invalidConfig.clientConfig) } if !strings.Contains(err.Error(), invalidConfig.errorMsg) { t.Errorf("Expected error %q does not contain %q", err.Error(), invalidConfig.errorMsg) } } } func TestCustomDialContextFunc(t *testing.T) { dialFn := func(_ context.Context, _, _ string) (net.Conn, error) { return nil, errors.New(ExpectedError) } cfg := HTTPClientConfig{} client, err := NewClientFromConfig(cfg, "test", WithDialContextFunc(dialFn)) if err != nil { t.Fatalf("Can't create a client from this config: %+v", cfg) } _, err = client.Get("http://localhost") if err == nil || !strings.Contains(err.Error(), ExpectedError) { t.Errorf("Expected error %q but got %q", ExpectedError, err) } } func TestMissingBearerAuthFile(t *testing.T) { cfg := HTTPClientConfig{ BearerTokenFile: MissingBearerTokenFile, TLSConfig: TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, } handler := func(w http.ResponseWriter, r *http.Request) { bearer := r.Header.Get("Authorization") if bearer != ExpectedBearer { fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } else { fmt.Fprint(w, ExpectedMessage) } } testServer, err := newTestServer(handler) if err != nil { t.Fatal(err.Error()) } defer testServer.Close() client, err := NewClientFromConfig(cfg, "test") if err != nil { t.Fatal(err) } _, err = client.Get(testServer.URL) if err == nil { t.Fatal("No error is returned here") } if !strings.Contains(err.Error(), "unable to read authorization credentials file missing/bearer.token: open missing/bearer.token: no such file or directory") { t.Fatal("wrong error message being returned") } } func TestBearerAuthRoundTripper(t *testing.T) { const ( newBearerToken = "goodbyeandthankyouforthefish" ) fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) { bearer := req.Header.Get("Authorization") if bearer != ExpectedBearer { t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } }, nil, nil) // Normal flow. bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", BearerToken, fakeRoundTripper) request, _ := http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("User-Agent", "Douglas Adams mind") _, err := bearerAuthRoundTripper.RoundTrip(request) if err != nil { t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) } // Should honor already Authorization header set. bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", newBearerToken, fakeRoundTripper) request, _ = http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("Authorization", ExpectedBearer) _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) if err != nil { t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) } } func TestBearerAuthFileRoundTripper(t *testing.T) { fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) { bearer := req.Header.Get("Authorization") if bearer != ExpectedBearer { t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)", ExpectedBearer, bearer) } }, nil, nil) // Normal flow. bearerAuthRoundTripper := NewAuthorizationCredentialsFileRoundTripper("Bearer", BearerTokenFile, fakeRoundTripper) request, _ := http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("User-Agent", "Douglas Adams mind") _, err := bearerAuthRoundTripper.RoundTrip(request) if err != nil { t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) } // Should honor already Authorization header set. bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsFileRoundTripper("Bearer", MissingBearerTokenFile, fakeRoundTripper) request, _ = http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("Authorization", ExpectedBearer) _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) if err != nil { t.Errorf("unexpected error while executing RoundTrip: %s", err.Error()) } } func TestTLSConfig(t *testing.T) { configTLSConfig := TLSConfig{ CAFile: TLSCAChainPath, CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "localhost", InsecureSkipVerify: false} tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath) if err != nil { t.Fatalf("Can't read the CA certificate chain (%s)", TLSCAChainPath) } rootCAs := x509.NewCertPool() rootCAs.AppendCertsFromPEM(tlsCAChain) expectedTLSConfig := &tls.Config{ RootCAs: rootCAs, ServerName: configTLSConfig.ServerName, InsecureSkipVerify: configTLSConfig.InsecureSkipVerify} tlsConfig, err := NewTLSConfig(&configTLSConfig) if err != nil { t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err) } clientCertificate, err := tls.LoadX509KeyPair(ClientCertificatePath, ClientKeyNoPassPath) if err != nil { t.Fatalf("Can't load the client key pair ('%s' and '%s'). Reason: %s", ClientCertificatePath, ClientKeyNoPassPath, err) } cert, err := tlsConfig.GetClientCertificate(nil) if err != nil { t.Fatalf("unexpected error returned by tlsConfig.GetClientCertificate(): %s", err) } if !reflect.DeepEqual(cert, &clientCertificate) { t.Fatalf("Unexpected client certificate result: \n\n%+v\n expected\n\n%+v", cert, clientCertificate) } // tlsConfig.rootCAs.LazyCerts contains functions getCert() in go 1.16, which are // never equal. Compare the Subjects instead. if !reflect.DeepEqual(tlsConfig.RootCAs.Subjects(), expectedTLSConfig.RootCAs.Subjects()) { t.Fatalf("Unexpected RootCAs result: \n\n%+v\n expected\n\n%+v", tlsConfig.RootCAs.Subjects(), expectedTLSConfig.RootCAs.Subjects()) } tlsConfig.RootCAs = nil expectedTLSConfig.RootCAs = nil // Non-nil functions are never equal. tlsConfig.GetClientCertificate = nil if !reflect.DeepEqual(tlsConfig, expectedTLSConfig) { t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", tlsConfig, expectedTLSConfig) } } func TestTLSConfigEmpty(t *testing.T) { configTLSConfig := TLSConfig{ InsecureSkipVerify: true, } expectedTLSConfig := &tls.Config{ InsecureSkipVerify: configTLSConfig.InsecureSkipVerify, } tlsConfig, err := NewTLSConfig(&configTLSConfig) if err != nil { t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err) } if !reflect.DeepEqual(tlsConfig, expectedTLSConfig) { t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", tlsConfig, expectedTLSConfig) } } func TestTLSConfigInvalidCA(t *testing.T) { var invalidTLSConfig = []struct { configTLSConfig TLSConfig errorMessage string }{ { configTLSConfig: TLSConfig{ CAFile: MissingCA, CertFile: "", KeyFile: "", ServerName: "", InsecureSkipVerify: false}, errorMessage: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA), }, { configTLSConfig: TLSConfig{ CAFile: "", CertFile: MissingCert, KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath), }, { configTLSConfig: TLSConfig{ CAFile: "", CertFile: ClientCertificatePath, KeyFile: MissingKey, ServerName: "", InsecureSkipVerify: false}, errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey), }, } for _, anInvalididTLSConfig := range invalidTLSConfig { tlsConfig, err := NewTLSConfig(&anInvalididTLSConfig.configTLSConfig) if tlsConfig != nil && err == nil { t.Errorf("The TLS Config could be created even with this %+v", anInvalididTLSConfig.configTLSConfig) continue } if !strings.Contains(err.Error(), anInvalididTLSConfig.errorMessage) { t.Errorf("The expected error should contain %s, but got %s", anInvalididTLSConfig.errorMessage, err) } } } func TestBasicAuthNoPassword(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.no-password.yaml") if err != nil { t.Fatalf("Error loading HTTP client config: %v", err) } client, err := NewClientFromConfig(*cfg, "test") if err != nil { t.Fatalf("Error creating HTTP Client: %v", err) } rt, ok := client.Transport.(*basicAuthRoundTripper) if !ok { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } if rt.username != "user" { t.Errorf("Bad HTTP client username: %s", rt.username) } if string(rt.password) != "" { t.Errorf("Expected empty HTTP client password: %s", rt.password) } if string(rt.passwordFile) != "" { t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile) } } func TestBasicAuthNoUsername(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.no-username.yaml") if err != nil { t.Fatalf("Error loading HTTP client config: %v", err) } client, err := NewClientFromConfig(*cfg, "test") if err != nil { t.Fatalf("Error creating HTTP Client: %v", err) } rt, ok := client.Transport.(*basicAuthRoundTripper) if !ok { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } if rt.username != "" { t.Errorf("Got unexpected username: %s", rt.username) } if string(rt.password) != "secret" { t.Errorf("Unexpected HTTP client password: %s", string(rt.password)) } if string(rt.passwordFile) != "" { t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile) } } func TestBasicAuthPasswordFile(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.good.yaml") if err != nil { t.Fatalf("Error loading HTTP client config: %v", err) } client, err := NewClientFromConfig(*cfg, "test") if err != nil { t.Fatalf("Error creating HTTP Client: %v", err) } rt, ok := client.Transport.(*basicAuthRoundTripper) if !ok { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } if rt.username != "user" { t.Errorf("Bad HTTP client username: %s", rt.username) } if string(rt.password) != "" { t.Errorf("Bad HTTP client password: %s", rt.password) } if string(rt.passwordFile) != "testdata/basic-auth-password" { t.Errorf("Bad HTTP client passwordFile: %s", rt.passwordFile) } } func getCertificateBlobs(t *testing.T) map[string][]byte { files := []string{ TLSCAChainPath, ClientCertificatePath, ClientKeyNoPassPath, ServerCertificatePath, ServerKeyPath, WrongClientCertPath, WrongClientKeyPath, EmptyFile, } bs := make(map[string][]byte, len(files)+1) for _, f := range files { b, err := ioutil.ReadFile(f) if err != nil { t.Fatal(err) } bs[f] = b } return bs } func writeCertificate(bs map[string][]byte, src string, dst string) { b, ok := bs[src] if !ok { panic(fmt.Sprintf("Couldn't find %q in bs", src)) } if err := ioutil.WriteFile(dst, b, 0664); err != nil { panic(err) } } func TestTLSRoundTripper(t *testing.T) { bs := getCertificateBlobs(t) tmpDir, err := ioutil.TempDir("", "tlsroundtripper") if err != nil { t.Fatal("Failed to create tmp dir", err) } defer os.RemoveAll(tmpDir) ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key") handler := func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, ExpectedMessage) } testServer, err := newTestServer(handler) if err != nil { t.Fatal(err.Error()) } defer testServer.Close() testCases := []struct { ca string cert string key string errMsg string }{ { // Valid certs. ca: TLSCAChainPath, cert: ClientCertificatePath, key: ClientKeyNoPassPath, }, { // CA not matching. ca: ClientCertificatePath, cert: ClientCertificatePath, key: ClientKeyNoPassPath, errMsg: "certificate signed by unknown authority", }, { // Invalid client cert+key. ca: TLSCAChainPath, cert: WrongClientCertPath, key: WrongClientKeyPath, errMsg: "remote error: tls", }, { // CA file empty ca: EmptyFile, cert: ClientCertificatePath, key: ClientKeyNoPassPath, errMsg: "unable to use specified CA cert", }, { // cert file empty ca: TLSCAChainPath, cert: EmptyFile, key: ClientKeyNoPassPath, errMsg: "failed to find any PEM data in certificate input", }, { // key file empty ca: TLSCAChainPath, cert: ClientCertificatePath, key: EmptyFile, errMsg: "failed to find any PEM data in key input", }, { // Valid certs again. ca: TLSCAChainPath, cert: ClientCertificatePath, key: ClientKeyNoPassPath, }, } cfg := HTTPClientConfig{ TLSConfig: TLSConfig{ CAFile: ca, CertFile: cert, KeyFile: key, InsecureSkipVerify: false}, } var c *http.Client for i, tc := range testCases { tc := tc t.Run(strconv.Itoa(i), func(t *testing.T) { writeCertificate(bs, tc.ca, ca) writeCertificate(bs, tc.cert, cert) writeCertificate(bs, tc.key, key) if c == nil { c, err = NewClientFromConfig(cfg, "test") if err != nil { t.Fatalf("Error creating HTTP Client: %v", err) } } req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) if err != nil { t.Fatalf("Error creating HTTP request: %v", err) } r, err := c.Do(req) if len(tc.errMsg) > 0 { if err == nil { r.Body.Close() t.Fatalf("Could connect to the test server.") } if !strings.Contains(err.Error(), tc.errMsg) { t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err) } return } if err != nil { t.Fatalf("Can't connect to the test server") } b, err := ioutil.ReadAll(r.Body) r.Body.Close() if err != nil { t.Errorf("Can't read the server response body") } got := strings.TrimSpace(string(b)) if ExpectedMessage != got { t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got) } }) } } func TestTLSRoundTripperRaces(t *testing.T) { bs := getCertificateBlobs(t) tmpDir, err := ioutil.TempDir("", "tlsroundtripper") if err != nil { t.Fatal("Failed to create tmp dir", err) } defer os.RemoveAll(tmpDir) ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key") handler := func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, ExpectedMessage) } testServer, err := newTestServer(handler) if err != nil { t.Fatal(err.Error()) } defer testServer.Close() cfg := HTTPClientConfig{ TLSConfig: TLSConfig{ CAFile: ca, CertFile: cert, KeyFile: key, InsecureSkipVerify: false}, } var c *http.Client writeCertificate(bs, TLSCAChainPath, ca) writeCertificate(bs, ClientCertificatePath, cert) writeCertificate(bs, ClientKeyNoPassPath, key) c, err = NewClientFromConfig(cfg, "test") if err != nil { t.Fatalf("Error creating HTTP Client: %v", err) } var wg sync.WaitGroup ch := make(chan struct{}) var total, ok int64 // Spawn 10 Go routines polling the server concurrently. for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() for { select { case <-ch: return default: atomic.AddInt64(&total, 1) r, err := c.Get(testServer.URL) if err == nil { r.Body.Close() atomic.AddInt64(&ok, 1) } } } }() } // Change the CA file every 10ms for 1 second. wg.Add(1) go func() { defer wg.Done() i := 0 for { tick := time.NewTicker(10 * time.Millisecond) <-tick.C if i%2 == 0 { writeCertificate(bs, ClientCertificatePath, ca) } else { writeCertificate(bs, TLSCAChainPath, ca) } i++ if i > 100 { close(ch) return } } }() wg.Wait() if ok == total { t.Fatalf("Expecting some requests to fail but got %d/%d successful requests", ok, total) } } func TestHideHTTPClientConfigSecrets(t *testing.T) { c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil { t.Errorf("Error parsing %s: %s", "testdata/http.conf.good.yml", err) } // String method must not reveal authentication credentials. s := c.String() if strings.Contains(s, "mysecret") { t.Fatal("http client config's String method reveals authentication credentials.") } } func TestDefaultFollowRedirect(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil { t.Errorf("Error loading HTTP client config: %v", err) } if !cfg.FollowRedirects { t.Errorf("follow_redirects should be true") } } func TestValidateHTTPConfig(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil { t.Errorf("Error loading HTTP client config: %v", err) } err = cfg.Validate() if err != nil { t.Fatalf("Error validating %s: %s", "testdata/http.conf.good.yml", err) } } func TestInvalidHTTPConfigs(t *testing.T) { for _, ee := range invalidHTTPClientConfigs { _, _, err := LoadHTTPConfigFile(ee.httpClientConfigFile) if err == nil { t.Error("Expected error with config but got none") continue } if !strings.Contains(err.Error(), ee.errMsg) { t.Errorf("Expected error for invalid HTTP client configuration to contain %q but got: %s", ee.errMsg, err) } } } // LoadHTTPConfig parses the YAML input s into a HTTPClientConfig. func LoadHTTPConfig(s string) (*HTTPClientConfig, error) { cfg := &HTTPClientConfig{} err := yaml.UnmarshalStrict([]byte(s), cfg) if err != nil { return nil, err } return cfg, nil } // LoadHTTPConfigFile parses the given YAML file into a HTTPClientConfig. func LoadHTTPConfigFile(filename string) (*HTTPClientConfig, []byte, error) { content, err := ioutil.ReadFile(filename) if err != nil { return nil, nil, err } cfg, err := LoadHTTPConfig(string(content)) if err != nil { return nil, nil, err } return cfg, content, nil } type roundTrip struct { theResponse *http.Response theError error } func (rt *roundTrip) RoundTrip(r *http.Request) (*http.Response, error) { return rt.theResponse, rt.theError } type roundTripCheckRequest struct { checkRequest func(*http.Request) roundTrip } func (rt *roundTripCheckRequest) RoundTrip(r *http.Request) (*http.Response, error) { rt.checkRequest(r) return rt.theResponse, rt.theError } // NewRoundTripCheckRequest creates a new instance of a type that implements http.RoundTripper, // which before returning theResponse and theError, executes checkRequest against a http.Request. func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *http.Response, theError error) http.RoundTripper { return &roundTripCheckRequest{ checkRequest: checkRequest, roundTrip: roundTrip{ theResponse: theResponse, theError: theError}} } type testServerResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` } func TestOAuth2(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { res, _ := json.Marshal(testServerResponse{ AccessToken: "12345", TokenType: "Bearer", }) w.Header().Add("Content-Type", "application/json") _, _ = w.Write(res) })) defer ts.Close() var yamlConfig = fmt.Sprintf(` client_id: 1 client_secret: 2 scopes: - A - B token_url: %s endpoint_params: hi: hello `, ts.URL) expectedConfig := OAuth2{ ClientID: "1", ClientSecret: "2", Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, TokenURL: ts.URL, } var unmarshalledConfig OAuth2 err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) if err != nil { t.Fatalf("Expected no error unmarshalling yaml, got %v", err) } if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) } rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) client := http.Client{ Transport: rt, } resp, _ := client.Get(ts.URL) authorization := resp.Request.Header.Get("Authorization") if authorization != "Bearer 12345" { t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } } func TestOAuth2WithFile(t *testing.T) { var expectedAuth *string var previousAuth string tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if auth != *expectedAuth { t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth) } if auth == previousAuth { t.Fatal("token endpoint called twice") } previousAuth = auth res, _ := json.Marshal(testServerResponse{ AccessToken: "12345", TokenType: "Bearer", }) w.Header().Add("Content-Type", "application/json") _, _ = w.Write(res) })) defer tokenTS.Close() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if auth != "Bearer 12345" { t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth) } fmt.Fprintln(w, "Hello, client") })) defer ts.Close() secretFile, err := ioutil.TempFile("", "oauth2_secret") if err != nil { t.Fatal(err) } defer os.Remove(secretFile.Name()) var yamlConfig = fmt.Sprintf(` client_id: 1 client_secret_file: %s scopes: - A - B token_url: %s endpoint_params: hi: hello `, secretFile.Name(), tokenTS.URL) expectedConfig := OAuth2{ ClientID: "1", ClientSecretFile: secretFile.Name(), Scopes: []string{"A", "B"}, EndpointParams: map[string]string{"hi": "hello"}, TokenURL: tokenTS.URL, } var unmarshalledConfig OAuth2 err = yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig) if err != nil { t.Fatalf("Expected no error unmarshalling yaml, got %v", err) } if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) { t.Fatalf("Got unmarshalled config %q, expected %q", unmarshalledConfig, expectedConfig) } rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport) client := http.Client{ Transport: rt, } tk := "Basic MToxMjM0NTY=" expectedAuth = &tk if _, err := secretFile.Write([]byte("123456")); err != nil { t.Fatal(err) } resp, err := client.Get(ts.URL) if err != nil { t.Fatal(err) } authorization := resp.Request.Header.Get("Authorization") if authorization != "Bearer 12345" { t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } // Making a second request with the same file content should not re-call the token API. resp, err = client.Get(ts.URL) if err != nil { t.Fatal(err) } tk = "Basic MToxMjM0NTY3" expectedAuth = &tk if _, err := secretFile.Write([]byte("7")); err != nil { t.Fatal(err) } _, err = client.Get(ts.URL) if err != nil { t.Fatal(err) } // Making a second request with the same file content should not re-call the token API. _, err = client.Get(ts.URL) if err != nil { t.Fatal(err) } authorization = resp.Request.Header.Get("Authorization") if authorization != "Bearer 12345" { t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization) } } func TestMarshalURL(t *testing.T) { urlp, err := url.Parse("http://example.com/") if err != nil { t.Fatal(err) } u := &URL{urlp} c, err := json.Marshal(u) if err != nil { t.Fatal(err) } if string(c) != "\"http://example.com/\"" { t.Fatalf("URL not properly marshaled in JSON got '%s'", string(c)) } c, err = yaml.Marshal(u) if err != nil { t.Fatal(err) } if string(c) != "http://example.com/\n" { t.Fatalf("URL not properly marshaled in YAML got '%s'", string(c)) } } func TestUnmarshalURL(t *testing.T) { b := []byte(`"http://example.com/a b"`) var u URL err := json.Unmarshal(b, &u) if err != nil { t.Fatal(err) } if u.String() != "http://example.com/a%20b" { t.Fatalf("URL not properly unmarshaled in JSON, got '%s'", u.String()) } err = yaml.Unmarshal(b, &u) if err != nil { t.Fatal(err) } if u.String() != "http://example.com/a%20b" { t.Fatalf("URL not properly unmarshaled in YAML, got '%s'", u.String()) } }