// Copyright 2017 David Ackroyd. All Rights Reserved. // See LICENSE for licensing terms. package grpc_recovery import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // RecoveryHandlerFunc is a function that recovers from the panic `p` by returning an `error`. type RecoveryHandlerFunc func(p interface{}) (err error) // RecoveryHandlerFuncContext is a function that recovers from the panic `p` by returning an `error`. // The context can be used to extract request scoped metadata and context values. type RecoveryHandlerFuncContext func(ctx context.Context, p interface{}) (err error) // UnaryServerInterceptor returns a new unary server interceptor for panic recovery. func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { o := evaluateOptions(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { panicked := true defer func() { if r := recover(); r != nil || panicked { err = recoverFrom(ctx, r, o.recoveryHandlerFunc) } }() resp, err := handler(ctx, req) panicked = false return resp, err } } // StreamServerInterceptor returns a new streaming server interceptor for panic recovery. func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { o := evaluateOptions(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { panicked := true defer func() { if r := recover(); r != nil || panicked { err = recoverFrom(stream.Context(), r, o.recoveryHandlerFunc) } }() err = handler(srv, stream) panicked = false return err } } func recoverFrom(ctx context.Context, p interface{}, r RecoveryHandlerFuncContext) error { if r == nil { return status.Errorf(codes.Internal, "%v", p) } return r(ctx, p) }