// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. // gRPC Server Interceptor chaining middleware. package grpc_middleware import ( "context" "google.golang.org/grpc" ) // ChainUnaryServer creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainUnaryServer(one, two, three) will execute one before two before three, and three // will see context changes of one and two. func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { n := len(interceptors) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { chainer := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler { return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) { return currentInter(currentCtx, currentReq, info, currentHandler) } } chainedHandler := handler for i := n - 1; i >= 0; i-- { chainedHandler = chainer(interceptors[i], chainedHandler) } return chainedHandler(ctx, req) } } // ChainStreamServer creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainUnaryServer(one, two, three) will execute one before two before three. // If you want to pass context between interceptors, use WrapServerStream. func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { n := len(interceptors) return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { chainer := func(currentInter grpc.StreamServerInterceptor, currentHandler grpc.StreamHandler) grpc.StreamHandler { return func(currentSrv interface{}, currentStream grpc.ServerStream) error { return currentInter(currentSrv, currentStream, info, currentHandler) } } chainedHandler := handler for i := n - 1; i >= 0; i-- { chainedHandler = chainer(interceptors[i], chainedHandler) } return chainedHandler(srv, ss) } } // ChainUnaryClient creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainUnaryClient(one, two, three) will execute one before two before three. func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { n := len(interceptors) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { chainer := func(currentInter grpc.UnaryClientInterceptor, currentInvoker grpc.UnaryInvoker) grpc.UnaryInvoker { return func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error { return currentInter(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentInvoker, currentOpts...) } } chainedInvoker := invoker for i := n - 1; i >= 0; i-- { chainedInvoker = chainer(interceptors[i], chainedInvoker) } return chainedInvoker(ctx, method, req, reply, cc, opts...) } } // ChainStreamClient creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainStreamClient(one, two, three) will execute one before two before three. func ChainStreamClient(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { n := len(interceptors) return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { chainer := func(currentInter grpc.StreamClientInterceptor, currentStreamer grpc.Streamer) grpc.Streamer { return func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) { return currentInter(currentCtx, currentDesc, currentConn, currentMethod, currentStreamer, currentOpts...) } } chainedStreamer := streamer for i := n - 1; i >= 0; i-- { chainedStreamer = chainer(interceptors[i], chainedStreamer) } return chainedStreamer(ctx, desc, cc, method, opts...) } } // Chain creates a single interceptor out of a chain of many interceptors. // // WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors. // Basically syntactic sugar. func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...)) } // WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors. // Basically syntactic sugar. func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { return grpc.StreamInterceptor(ChainStreamServer(interceptors...)) }