// Copyright 2013 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pretty import ( "encoding" "fmt" "reflect" "sort" ) func isZeroVal(val reflect.Value) bool { if !val.CanInterface() { return false } z := reflect.Zero(val.Type()).Interface() return reflect.DeepEqual(val.Interface(), z) } // pointerTracker is a helper for tracking pointer chasing to detect cycles. type pointerTracker struct { addrs map[uintptr]int // addr[address] = seen count lastID int ids map[uintptr]int // ids[address] = id } // track tracks following a reference (pointer, slice, map, etc). Every call to // track should be paired with a call to untrack. func (p *pointerTracker) track(ptr uintptr) { if p.addrs == nil { p.addrs = make(map[uintptr]int) } p.addrs[ptr]++ } // untrack registers that we have backtracked over the reference to the pointer. func (p *pointerTracker) untrack(ptr uintptr) { p.addrs[ptr]-- if p.addrs[ptr] == 0 { delete(p.addrs, ptr) } } // seen returns whether the pointer was previously seen along this path. func (p *pointerTracker) seen(ptr uintptr) bool { _, ok := p.addrs[ptr] return ok } // keep allocates an ID for the given address and returns it. func (p *pointerTracker) keep(ptr uintptr) int { if p.ids == nil { p.ids = make(map[uintptr]int) } if _, ok := p.ids[ptr]; !ok { p.lastID++ p.ids[ptr] = p.lastID } return p.ids[ptr] } // id returns the ID for the given address. func (p *pointerTracker) id(ptr uintptr) (int, bool) { if p.ids == nil { p.ids = make(map[uintptr]int) } id, ok := p.ids[ptr] return id, ok } // reflector adds local state to the recursive reflection logic. type reflector struct { *Config *pointerTracker } // follow handles following a possiblly-recursive reference to the given value // from the given ptr address. func (r *reflector) follow(ptr uintptr, val reflect.Value) node { if r.pointerTracker == nil { // Tracking disabled return r.val2node(val) } // If a parent already followed this, emit a reference marker if r.seen(ptr) { id := r.keep(ptr) return ref{id} } // Track the pointer we're following while on this recursive branch r.track(ptr) defer r.untrack(ptr) n := r.val2node(val) // If the recursion used this ptr, wrap it with a target marker if id, ok := r.id(ptr); ok { return target{id, n} } // Otherwise, return the node unadulterated return n } func (r *reflector) val2node(val reflect.Value) node { if !val.IsValid() { return rawVal("nil") } if val.CanInterface() { v := val.Interface() if formatter, ok := r.Formatter[val.Type()]; ok { if formatter != nil { res := reflect.ValueOf(formatter).Call([]reflect.Value{val}) return rawVal(res[0].Interface().(string)) } } else { if s, ok := v.(fmt.Stringer); ok && r.PrintStringers { return stringVal(s.String()) } if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers { if raw, err := t.MarshalText(); err == nil { // if NOT an error return stringVal(string(raw)) } } } } switch kind := val.Kind(); kind { case reflect.Ptr: if val.IsNil() { return rawVal("nil") } return r.follow(val.Pointer(), val.Elem()) case reflect.Interface: if val.IsNil() { return rawVal("nil") } return r.val2node(val.Elem()) case reflect.String: return stringVal(val.String()) case reflect.Slice: n := list{} length := val.Len() ptr := val.Pointer() for i := 0; i < length; i++ { n = append(n, r.follow(ptr, val.Index(i))) } return n case reflect.Array: n := list{} length := val.Len() for i := 0; i < length; i++ { n = append(n, r.val2node(val.Index(i))) } return n case reflect.Map: // Extract the keys and sort them for stable iteration keys := val.MapKeys() pairs := make([]mapPair, 0, len(keys)) for _, key := range keys { pairs = append(pairs, mapPair{ key: new(formatter).compactString(r.val2node(key)), // can't be cyclic value: val.MapIndex(key), }) } sort.Sort(byKey(pairs)) // Process the keys into the final representation ptr, n := val.Pointer(), keyvals{} for _, pair := range pairs { n = append(n, keyval{ key: pair.key, val: r.follow(ptr, pair.value), }) } return n case reflect.Struct: n := keyvals{} typ := val.Type() fields := typ.NumField() for i := 0; i < fields; i++ { sf := typ.Field(i) if !r.IncludeUnexported && sf.PkgPath != "" { continue } field := val.Field(i) if r.SkipZeroFields && isZeroVal(field) { continue } n = append(n, keyval{sf.Name, r.val2node(field)}) } return n case reflect.Bool: if val.Bool() { return rawVal("true") } return rawVal("false") case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rawVal(fmt.Sprintf("%d", val.Int())) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return rawVal(fmt.Sprintf("%d", val.Uint())) case reflect.Uintptr: return rawVal(fmt.Sprintf("0x%X", val.Uint())) case reflect.Float32, reflect.Float64: return rawVal(fmt.Sprintf("%v", val.Float())) case reflect.Complex64, reflect.Complex128: return rawVal(fmt.Sprintf("%v", val.Complex())) } // Fall back to the default %#v if we can if val.CanInterface() { return rawVal(fmt.Sprintf("%#v", val.Interface())) } return rawVal(val.String()) } type mapPair struct { key string value reflect.Value } type byKey []mapPair func (v byKey) Len() int { return len(v) } func (v byKey) Swap(i, j int) { v[i], v[j] = v[j], v[i] } func (v byKey) Less(i, j int) bool { return v[i].key < v[j].key }