/* * * Copyright 2022 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package test import ( "context" "fmt" "testing" "google.golang.org/grpc" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" testpb "google.golang.org/grpc/test/grpc_testing" ) type parentCtxkey struct{} type firstInterceptorCtxkey struct{} type secondInterceptorCtxkey struct{} type baseInterceptorCtxKey struct{} const ( parentCtxVal = "parent" firstInterceptorCtxVal = "firstInterceptor" secondInterceptorCtxVal = "secondInterceptor" baseInterceptorCtxVal = "baseInterceptor" ) // TestUnaryClientInterceptor_ContextValuePropagation verifies that a unary // interceptor receives context values specified in the context passed to the // RPC call. func (s) TestUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { errCh := testutils.NewChannel() unaryInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.Send(fmt.Errorf("unaryInt got %q in context.Val, want %q", got, parentCtxVal)) } errCh.Send(nil) return invoker(ctx, method, req, reply, cc, opts...) } // Start a stub server and use the above unary interceptor while creating a // ClientConn to it. ss := &stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, } if err := ss.Start(nil, grpc.WithUnaryInterceptor(unaryInt)); err != nil { t.Fatalf("Failed to start stub server: %v", err) } defer ss.Stop() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { t.Fatalf("ss.Client.EmptyCall() failed: %v", err) } val, err := errCh.Receive(ctx) if err != nil { t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) } if val != nil { t.Fatalf("unary interceptor failed: %v", val) } } // TestChainUnaryClientInterceptor_ContextValuePropagation verifies that a chain // of unary interceptors receive context values specified in the original call // as well as the ones specified by prior interceptors in the chain. func (s) TestChainUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { errCh := testutils.NewChannel() firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if ctx.Value(firstInterceptorCtxkey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{})) } if ctx.Value(secondInterceptorCtxkey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{})) } firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal) return invoker(firstCtx, method, req, reply, cc, opts...) } secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) } if ctx.Value(secondInterceptorCtxkey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{})) } secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal) return invoker(secondCtx, method, req, reply, cc, opts...) } lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) } if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal)) } errCh.SendContext(ctx, nil) return invoker(ctx, method, req, reply, cc, opts...) } // Start a stub server and use the above chain of interceptors while creating // a ClientConn to it. ss := &stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, } if err := ss.Start(nil, grpc.WithChainUnaryInterceptor(firstInt, secondInt, lastInt)); err != nil { t.Fatalf("Failed to start stub server: %v", err) } defer ss.Stop() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { t.Fatalf("ss.Client.EmptyCall() failed: %v", err) } val, err := errCh.Receive(ctx) if err != nil { t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) } if val != nil { t.Fatalf("unary interceptor failed: %v", val) } } // TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation verifies that // unary interceptors specified as a base interceptor or as a chain interceptor // receive context values specified in the original call as well as the ones // specified by interceptors in the chain. func (s) TestChainOnBaseUnaryClientInterceptor_ContextValuePropagation(t *testing.T) { errCh := testutils.NewChannel() baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("base interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if ctx.Value(baseInterceptorCtxKey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("baseinterceptor should not have %T in context", baseInterceptorCtxKey{})) } baseCtx := context.WithValue(ctx, baseInterceptorCtxKey{}, baseInterceptorCtxVal) return invoker(baseCtx, method, req, reply, cc, opts...) } chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if got, ok := ctx.Value(baseInterceptorCtxKey{}).(string); !ok || got != baseInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("chain interceptor got %q in context.Val, want %q", got, baseInterceptorCtxVal)) } errCh.SendContext(ctx, nil) return invoker(ctx, method, req, reply, cc, opts...) } // Start a stub server and use the above chain of interceptors while creating // a ClientConn to it. ss := &stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, } if err := ss.Start(nil, grpc.WithUnaryInterceptor(baseInt), grpc.WithChainUnaryInterceptor(chainInt)); err != nil { t.Fatalf("Failed to start stub server: %v", err) } defer ss.Stop() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.EmptyCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal), &testpb.Empty{}); err != nil { t.Fatalf("ss.Client.EmptyCall() failed: %v", err) } val, err := errCh.Receive(ctx) if err != nil { t.Fatalf("timeout when waiting for unary interceptor to be invoked: %v", err) } if val != nil { t.Fatalf("unary interceptor failed: %v", val) } } // TestChainStreamClientInterceptor_ContextValuePropagation verifies that a // chain of stream interceptors receive context values specified in the original // call as well as the ones specified by the prior interceptors in the chain. func (s) TestChainStreamClientInterceptor_ContextValuePropagation(t *testing.T) { errCh := testutils.NewChannel() firstInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("first interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if ctx.Value(firstInterceptorCtxkey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", firstInterceptorCtxkey{})) } if ctx.Value(secondInterceptorCtxkey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("first interceptor should not have %T in context", secondInterceptorCtxkey{})) } firstCtx := context.WithValue(ctx, firstInterceptorCtxkey{}, firstInterceptorCtxVal) return streamer(firstCtx, desc, cc, method, opts...) } secondInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("second interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) } if ctx.Value(secondInterceptorCtxkey{}) != nil { errCh.SendContext(ctx, fmt.Errorf("second interceptor should not have %T in context", secondInterceptorCtxkey{})) } secondCtx := context.WithValue(ctx, secondInterceptorCtxkey{}, secondInterceptorCtxVal) return streamer(secondCtx, desc, cc, method, opts...) } lastInt := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if got, ok := ctx.Value(parentCtxkey{}).(string); !ok || got != parentCtxVal { errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, parentCtxVal)) } if got, ok := ctx.Value(firstInterceptorCtxkey{}).(string); !ok || got != firstInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, firstInterceptorCtxVal)) } if got, ok := ctx.Value(secondInterceptorCtxkey{}).(string); !ok || got != secondInterceptorCtxVal { errCh.SendContext(ctx, fmt.Errorf("last interceptor got %q in context.Val, want %q", got, secondInterceptorCtxVal)) } errCh.SendContext(ctx, nil) return streamer(ctx, desc, cc, method, opts...) } // Start a stub server and use the above chain of interceptors while creating // a ClientConn to it. ss := &stubserver.StubServer{ FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { if _, err := stream.Recv(); err != nil { return err } return stream.Send(&testpb.StreamingOutputCallResponse{}) }, } if err := ss.Start(nil, grpc.WithChainStreamInterceptor(firstInt, secondInt, lastInt)); err != nil { t.Fatalf("Failed to start stub server: %v", err) } defer ss.Stop() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := ss.Client.FullDuplexCall(context.WithValue(ctx, parentCtxkey{}, parentCtxVal)); err != nil { t.Fatalf("ss.Client.FullDuplexCall() failed: %v", err) } val, err := errCh.Receive(ctx) if err != nil { t.Fatalf("timeout when waiting for stream interceptor to be invoked: %v", err) } if val != nil { t.Fatalf("stream interceptor failed: %v", val) } }