// Copyright 2019 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package testutil import ( "bytes" "context" "errors" "fmt" "log" "os" "strings" "google.golang.org/api/option" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) // HeaderChecker defines header checking and validation rules for any outgoing metadata. type HeaderChecker struct { // Key is the header name to be checked against e.g. "x-goog-api-client". Key string // ValuesValidator validates the header values retrieved from mapping against // Key in the Headers. ValuesValidator func(values ...string) error } // HeadersEnforcer asserts that outgoing RPC headers // are present and match expectations. If the expected headers // are not present or don't match expectations, it'll invoke OnFailure // with the validation error, or instead log.Fatal if OnFailure is nil. // // It expects that every declared key will be present in the outgoing // RPC header and each value will be validated by the validation function. type HeadersEnforcer struct { // Checkers maps header keys that are expected to be sent in the metadata // of outgoing gRPC requests, against the values passed into the custom // validation functions. // // If Checkers is nil or empty, only the default header "x-goog-api-client" // will be checked for. // Otherwise, if you supply Matchers, those keys and their respective // validation functions will be checked. Checkers []*HeaderChecker // OnFailure is the function that will be invoked after all validation // failures have been composed. If OnFailure is nil, log.Fatal will be // invoked instead. OnFailure func(fmt_ string, args ...interface{}) } // StreamInterceptors returns a list of StreamClientInterceptor functions which // enforce the presence and validity of expected headers during streaming RPCs. // // For client implementations which provide their own StreamClientInterceptor(s) // these interceptors should be specified as the final elements to // WithChainStreamInterceptor. // // Alternatively, users may apply gPRC options produced from DialOptions to // apply all applicable gRPC interceptors. func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor { return []grpc.StreamClientInterceptor{h.interceptStream} } // UnaryInterceptors returns a list of UnaryClientInterceptor functions which // enforce the presence and validity of expected headers during unary RPCs. // // For client implementations which provide their own UnaryClientInterceptor(s) // these interceptors should be specified as the final elements to // WithChainUnaryInterceptor. // // Alternatively, users may apply gPRC options produced from DialOptions to // apply all applicable gRPC interceptors. func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor { return []grpc.UnaryClientInterceptor{h.interceptUnary} } // DialOptions returns gRPC DialOptions consisting of unary and stream interceptors // to enforce the presence and validity of expected headers. func (h *HeadersEnforcer) DialOptions() []grpc.DialOption { return []grpc.DialOption{ grpc.WithChainStreamInterceptor(h.interceptStream), grpc.WithChainUnaryInterceptor(h.interceptUnary), } } // CallOptions returns ClientOptions consisting of unary and stream interceptors // to enforce the presence and validity of expected headers. func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) { dopts := h.DialOptions() for _, dopt := range dopts { copts = append(copts, option.WithGRPCDialOption(dopt)) } return } func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { h.checkMetadata(ctx, method) return invoker(ctx, method, req, res, cc, opts...) } func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { h.checkMetadata(ctx, method) return streamer(ctx, desc, cc, method, opts...) } // XGoogClientHeaderChecker is a HeaderChecker that ensures that the "x-goog-api-client" // header is present on outgoing metadata. var XGoogClientHeaderChecker = &HeaderChecker{ Key: "x-goog-api-client", ValuesValidator: func(values ...string) error { if len(values) == 0 { return errors.New("expecting values") } for _, value := range values { switch { case strings.Contains(value, "gl-go/"): // TODO: check for exact version strings. return nil default: // Add others here. } } return errors.New("unmatched values") }, } // DefaultHeadersEnforcer returns a HeadersEnforcer that at bare minimum checks that // the "x-goog-api-client" key is present in the outgoing metadata headers. On any // validation failure, it will invoke log.Fatalf with the error message. func DefaultHeadersEnforcer() *HeadersEnforcer { return &HeadersEnforcer{ Checkers: []*HeaderChecker{XGoogClientHeaderChecker}, } } func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) { onFailure := h.OnFailure if onFailure == nil { lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs. onFailure = func(fmt_ string, args ...interface{}) { lgr.Fatalf(fmt_, args...) } } md, ok := metadata.FromOutgoingContext(ctx) if !ok { onFailure("Missing metadata for method %q", method) return } checkers := h.Checkers if len(checkers) == 0 { // Instead use the default HeaderChecker. checkers = append(checkers, XGoogClientHeaderChecker) } errBuf := new(bytes.Buffer) for _, checker := range checkers { hdrKey := checker.Key outHdrValues, ok := md[hdrKey] if !ok { fmt.Fprintf(errBuf, "missing header %q\n", hdrKey) continue } if err := checker.ValuesValidator(outHdrValues...); err != nil { fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err) } } if errBuf.Len() != 0 { onFailure("For method %q, errors:\n%s", method, errBuf) return } }