/* * * Copyright 2020 gRPC authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package grpctest import ( "errors" "fmt" "os" "regexp" "runtime/debug" "strconv" "strings" "sync" "testing" "google.golang.org/grpc/grpclog" ) // TLogger serves as the grpclog logger and is the interface through which // expected errors are declared in tests. var TLogger *tLogger const callingFrame = 4 type logType int const ( logLog logType = iota errorLog fatalLog ) type tLogger struct { v int t *testing.T initialized bool m sync.Mutex // protects errors errors map[*regexp.Regexp]int } func init() { TLogger = &tLogger{errors: map[*regexp.Regexp]int{}} vLevel := os.Getenv("GRPC_GO_LOG_VERBOSITY_LEVEL") if vl, err := strconv.Atoi(vLevel); err == nil { TLogger.v = vl } } // getStackFrame gets, from the stack byte string, the appropriate stack frame. func getStackFrame(stack []byte, frame int) (string, error) { s := strings.Split(string(stack), "\n") if frame >= (len(s)-1)/2 { return "", errors.New("frame request out-of-bounds") } split := strings.Split(strings.Fields(s[(frame*2)+2][1:])[0], "/") return fmt.Sprintf("%v:", split[len(split)-1]), nil } // log logs the message with the specified parameters to the tLogger. func (g *tLogger) log(ltype logType, depth int, format string, args ...interface{}) { s := debug.Stack() prefix, err := getStackFrame(s, callingFrame+depth) args = append([]interface{}{prefix}, args...) if err != nil { g.t.Error(err) return } if format == "" { switch ltype { case errorLog: // fmt.Sprintln is used rather than fmt.Sprint because t.Log uses fmt.Sprintln behavior. if g.expected(fmt.Sprintln(args...)) { g.t.Log(args...) } else { g.t.Error(args...) } case fatalLog: panic(fmt.Sprint(args...)) default: g.t.Log(args...) } } else { format = "%v " + format switch ltype { case errorLog: if g.expected(fmt.Sprintf(format, args...)) { g.t.Logf(format, args...) } else { g.t.Errorf(format, args...) } case fatalLog: panic(fmt.Sprintf(format, args...)) default: g.t.Logf(format, args...) } } } // Update updates the testing.T that the testing logger logs to. Should be done // before every test. It also initializes the tLogger if it has not already. func (g *tLogger) Update(t *testing.T) { if !g.initialized { grpclog.SetLoggerV2(TLogger) g.initialized = true } g.t = t g.m.Lock() defer g.m.Unlock() g.errors = map[*regexp.Regexp]int{} } // ExpectError declares an error to be expected. For the next test, the first // error log matching the expression (using FindString) will not cause the test // to fail. "For the next test" includes all the time until the next call to // Update(). Note that if an expected error is not encountered, this will cause // the test to fail. func (g *tLogger) ExpectError(expr string) { g.ExpectErrorN(expr, 1) } // ExpectErrorN declares an error to be expected n times. func (g *tLogger) ExpectErrorN(expr string, n int) { re, err := regexp.Compile(expr) if err != nil { g.t.Error(err) return } g.m.Lock() defer g.m.Unlock() g.errors[re] += n } // EndTest checks if expected errors were not encountered. func (g *tLogger) EndTest(t *testing.T) { g.m.Lock() defer g.m.Unlock() for re, count := range g.errors { if count > 0 { t.Errorf("Expected error '%v' not encountered", re.String()) } } g.errors = map[*regexp.Regexp]int{} } // expected determines if the error string is protected or not. func (g *tLogger) expected(s string) bool { g.m.Lock() defer g.m.Unlock() for re, count := range g.errors { if re.FindStringIndex(s) != nil { g.errors[re]-- if count <= 1 { delete(g.errors, re) } return true } } return false } func (g *tLogger) Info(args ...interface{}) { g.log(logLog, 0, "", args...) } func (g *tLogger) Infoln(args ...interface{}) { g.log(logLog, 0, "", args...) } func (g *tLogger) Infof(format string, args ...interface{}) { g.log(logLog, 0, format, args...) } func (g *tLogger) InfoDepth(depth int, args ...interface{}) { g.log(logLog, depth, "", args...) } func (g *tLogger) Warning(args ...interface{}) { g.log(logLog, 0, "", args...) } func (g *tLogger) Warningln(args ...interface{}) { g.log(logLog, 0, "", args...) } func (g *tLogger) Warningf(format string, args ...interface{}) { g.log(logLog, 0, format, args...) } func (g *tLogger) WarningDepth(depth int, args ...interface{}) { g.log(logLog, depth, "", args...) } func (g *tLogger) Error(args ...interface{}) { g.log(errorLog, 0, "", args...) } func (g *tLogger) Errorln(args ...interface{}) { g.log(errorLog, 0, "", args...) } func (g *tLogger) Errorf(format string, args ...interface{}) { g.log(errorLog, 0, format, args...) } func (g *tLogger) ErrorDepth(depth int, args ...interface{}) { g.log(errorLog, depth, "", args...) } func (g *tLogger) Fatal(args ...interface{}) { g.log(fatalLog, 0, "", args...) } func (g *tLogger) Fatalln(args ...interface{}) { g.log(fatalLog, 0, "", args...) } func (g *tLogger) Fatalf(format string, args ...interface{}) { g.log(fatalLog, 0, format, args...) } func (g *tLogger) FatalDepth(depth int, args ...interface{}) { g.log(fatalLog, depth, "", args...) } func (g *tLogger) V(l int) bool { return l <= g.v }