/* * * Copyright 2020 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package test import ( "context" "fmt" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/grpc/codes" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) type funcConfigSelector struct { f func(iresolver.RPCInfo) (*iresolver.RPCConfig, error) } func (f funcConfigSelector) SelectConfig(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { return f.f(i) } func (s) TestConfigSelector(t *testing.T) { gotContextChan := testutils.NewChannelWithSize(1) ss := &stubserver.StubServer{ EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { gotContextChan.SendContext(ctx, ctx) return &testpb.Empty{}, nil }, } ss.R = manual.NewBuilderWithScheme("confSel") if err := ss.Start(nil); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() ctxDeadline := time.Now().Add(10 * time.Second) ctx, cancel := context.WithDeadline(context.Background(), ctxDeadline) defer cancel() longCtxDeadline := time.Now().Add(30 * time.Second) longdeadlineCtx, cancel := context.WithDeadline(context.Background(), longCtxDeadline) defer cancel() shorterTimeout := 3 * time.Second testMD := metadata.MD{"footest": []string{"bazbar"}} mdOut := metadata.MD{"handler": []string{"value"}} var onCommittedCalled bool testCases := []struct { name string md metadata.MD // MD sent with RPC config *iresolver.RPCConfig // config returned by config selector csErr error // error returned by config selector wantMD metadata.MD wantDeadline time.Time wantTimeout time.Duration wantErr error }{{ name: "basic", md: testMD, config: &iresolver.RPCConfig{}, wantMD: testMD, wantDeadline: ctxDeadline, }, { name: "alter MD", md: testMD, config: &iresolver.RPCConfig{ Context: metadata.NewOutgoingContext(ctx, mdOut), }, wantMD: mdOut, wantDeadline: ctxDeadline, }, { name: "erroring SelectConfig", csErr: status.Errorf(codes.Unavailable, "cannot send RPC"), wantErr: status.Errorf(codes.Unavailable, "cannot send RPC"), }, { name: "alter timeout; remove MD", md: testMD, config: &iresolver.RPCConfig{ Context: longdeadlineCtx, // no metadata }, wantMD: nil, wantDeadline: longCtxDeadline, }, { name: "nil config", md: metadata.MD{}, config: nil, wantMD: nil, wantDeadline: ctxDeadline, }, { name: "alter timeout via method config; remove MD", md: testMD, config: &iresolver.RPCConfig{ MethodConfig: serviceconfig.MethodConfig{ Timeout: &shorterTimeout, }, }, wantMD: nil, wantTimeout: shorterTimeout, }, { name: "onCommitted callback", md: testMD, config: &iresolver.RPCConfig{ OnCommitted: func() { onCommittedCalled = true }, }, wantMD: testMD, wantDeadline: ctxDeadline, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var gotInfo *iresolver.RPCInfo state := iresolver.SetConfigSelector(resolver.State{ Addresses: []resolver.Address{{Addr: ss.Address}}, ServiceConfig: parseCfg(ss.R, "{}"), }, funcConfigSelector{ f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { gotInfo = &i cfg := tc.config if cfg != nil && cfg.Context == nil { cfg.Context = i.Context } return cfg, tc.csErr }, }) ss.R.UpdateState(state) // Blocks until config selector is applied onCommittedCalled = false ctx := metadata.NewOutgoingContext(ctx, tc.md) startTime := time.Now() if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); fmt.Sprint(err) != fmt.Sprint(tc.wantErr) { t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.wantErr) } else if err != nil { return // remaining checks are invalid } if gotInfo == nil { t.Fatalf("no config selector data") } if want := "/grpc.testing.TestService/EmptyCall"; gotInfo.Method != want { t.Errorf("gotInfo.Method = %q; want %q", gotInfo.Method, want) } gotContextI, ok := gotContextChan.ReceiveOrFail() if !ok { t.Fatalf("no context received") } gotContext := gotContextI.(context.Context) gotMD, _ := metadata.FromOutgoingContext(gotInfo.Context) if diff := cmp.Diff(tc.md, gotMD); diff != "" { t.Errorf("gotInfo.Context contains MD %v; want %v\nDiffs: %v", gotMD, tc.md, diff) } gotMD, _ = metadata.FromIncomingContext(gotContext) // Remove entries from gotMD not in tc.wantMD (e.g. authority header). for k := range gotMD { if _, ok := tc.wantMD[k]; !ok { delete(gotMD, k) } } if diff := cmp.Diff(tc.wantMD, gotMD, cmpopts.EquateEmpty()); diff != "" { t.Errorf("received md = %v; want %v\nDiffs: %v", gotMD, tc.wantMD, diff) } wantDeadline := tc.wantDeadline if wantDeadline == (time.Time{}) { wantDeadline = startTime.Add(tc.wantTimeout) } deadlineGot, _ := gotContext.Deadline() if diff := deadlineGot.Sub(wantDeadline); diff > time.Second || diff < -time.Second { t.Errorf("received deadline = %v; want ~%v", deadlineGot, wantDeadline) } if tc.config != nil && tc.config.OnCommitted != nil && !onCommittedCalled { t.Errorf("OnCommitted callback not called") } }) } }