// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package protorange provides functionality to traverse a message value.
package protorange

import (
	"bytes"
	"errors"

	"google.golang.org/protobuf/internal/genid"
	"google.golang.org/protobuf/internal/order"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protopath"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
)

var (
	// Break breaks traversal of children in the current value.
	// It has no effect when traversing values that are not composite types
	// (e.g., messages, lists, and maps).
	Break = errors.New("break traversal of children in current value")

	// Terminate terminates the entire range operation.
	// All necessary Pop operations continue to be called.
	Terminate = errors.New("terminate range operation")
)

// Range performs a depth-first traversal over reachable values in a message.
// The function f is used as the push callback; no pop callback is installed.
//
// See Options.Range for details.
func Range(m protoreflect.Message, f func(protopath.Values) error) error {
	return Options{}.Range(m, f, nil)
}

// Options configures traversal of a message value tree.
type Options struct {
	// Stable specifies whether to visit message fields and map entries
	// in a stable ordering. If false, then the ordering is undefined and
	// may be non-deterministic.
	//
	// Message fields are visited in ascending order by field number.
	// Map entries are visited in ascending order, where
	// boolean keys are ordered such that false sorts before true,
	// numeric keys are ordered based on the numeric value, and
	// string keys are lexicographically ordered by Unicode codepoints.
	Stable bool

	// Resolver is used for looking up types when expanding google.protobuf.Any
	// messages. If nil, this defaults to using protoregistry.GlobalTypes.
	// To prevent expansion of Any messages, pass an empty protoregistry.Types:
	//
	//	Options{Resolver: (*protoregistry.Types)(nil)}
	//
	Resolver interface {
		protoregistry.ExtensionTypeResolver
		protoregistry.MessageTypeResolver
	}
}

// Range performs a depth-first traversal over reachable values in a message.
// The first push and the last pop are to push/pop a protopath.Root step.
// If push or pop return any non-nil error (other than Break or Terminate),
// it terminates the traversal and is returned by Range.
// Either push or pop may be nil, in which case that callback is skipped.
//
// The rules for traversing a message are as follows:
//
//   - For messages, iterate over every populated known and extension field.
//     Each field is preceded by a push of a protopath.FieldAccess step,
//     followed by recursive application of the rules on the field value,
//     and succeeded by a pop of that step.
//     If the message has unknown fields, then push a protopath.UnknownAccess step
//     followed immediately by pop of that step.
//
//   - As an exception to the above rule, if the current message is a
//     google.protobuf.Any message, expand the underlying message (if resolvable).
//     The expanded message is preceded by a push of a protopath.AnyExpand step,
//     followed by recursive application of the rules on the underlying message,
//     and succeeded by a pop of that step. Mutations to the expanded message
//     are written back to the Any message when popping back out.
//
//   - For lists, iterate over every element. Each element is preceded by a push
//     of a protopath.ListIndex step, followed by recursive application of the rules
//     on the list element, and succeeded by a pop of that step.
//
//   - For maps, iterate over every entry. Each entry is preceded by a push
//     of a protopath.MapIndex step, followed by recursive application of the rules
//     on the map entry value, and succeeded by a pop of that step.
//
// Mutations should only be made to the last value, otherwise the effects on
// traversal will be undefined. If the mutation is made to the last value
// during a push, then the effects of the mutation will affect traversal.
// For example, if the last value is currently a message, and the push function
// populates a few fields in that message, then the newly modified fields
// will be traversed.
//
// The protopath.Values provided to push functions are only valid until the
// corresponding pop call and the values provided to a pop call are only valid
// for the duration of the pop call itself.
func (o Options) Range(m protoreflect.Message, push, pop func(protopath.Values) error) error {
	var err error
	p := new(protopath.Values)
	// Only an untyped nil selects the global registry. A typed nil such as
	// (*protoregistry.Types)(nil) is a non-nil interface value and is kept.
	if o.Resolver == nil {
		o.Resolver = protoregistry.GlobalTypes
	}

	pushStep(p, protopath.Root(m.Descriptor()), protoreflect.ValueOfMessage(m))
	if push != nil {
		err = amendError(err, push(*p))
	}
	// Descend only if the root push did not report an error (or Break/Terminate).
	if err == nil {
		err = o.rangeMessage(p, m, push, pop)
	}
	// The root pop is always invoked so that it pairs with the root push.
	if pop != nil {
		err = amendError(err, pop(*p))
	}
	popStep(p)

	// Break and Terminate are control signals, not failures to report.
	if err == Break || err == Terminate {
		err = nil
	}
	return err
}

// rangeMessage applies the traversal rules to the message m, whose step is
// already the last entry of p. If rangeAnyMessage handles m as an expanded
// google.protobuf.Any, its result is returned directly. Otherwise each field
// yielded by order.RangeFields is pushed, recursed into when it is a map,
// a list, or a message, and popped; afterwards any unknown bytes are pushed
// and popped as a single protopath.UnknownAccess step.
//
// A Break recorded while visiting the fields is cleared before returning;
// every other error (including Terminate) is returned to the caller.
func (o Options) rangeMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (err error) {
	// The err declared here shadows the named result, hence the explicit return.
	if ok, err := o.rangeAnyMessage(p, m, push, pop); ok {
		return err
	}

	fieldOrder := order.AnyFieldOrder
	if o.Stable {
		fieldOrder = order.NumberFieldOrder
	}
	order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
		pushStep(p, protopath.FieldAccess(fd), v)
		if push != nil {
			err = amendError(err, push(*p))
		}
		// Recurse only when push succeeded; scalar fields match no case.
		if err == nil {
			switch {
			case fd.IsMap():
				err = o.rangeMap(p, fd, v.Map(), push, pop)
			case fd.IsList():
				err = o.rangeList(p, fd, v.List(), push, pop)
			case fd.Message() != nil:
				err = o.rangeMessage(p, v.Message(), push, pop)
			}
		}
		// pop is called even after an error so that every push is paired.
		if pop != nil {
			err = amendError(err, pop(*p))
		}
		popStep(p)
		// Any recorded error, Break included, yields false here.
		// NOTE(review): false presumably stops order.RangeFields from
		// visiting the remaining fields; confirm in internal/order.
		return err == nil
	})

	// Unknown fields are visited last, and only if no error was recorded above.
	if b := m.GetUnknown(); len(b) > 0 && err == nil {
		pushStep(p, protopath.UnknownAccess(), protoreflect.ValueOfBytes(b))
		if push != nil {
			err = amendError(err, push(*p))
		}
		if pop != nil {
			err = amendError(err, pop(*p))
		}
		popStep(p)
	}

	if err == Break {
		err = nil
	}
	return err
}

// rangeAnyMessage attempts to traverse m as an expanded google.protobuf.Any.
//
// It reports ok == false, leaving p untouched, when m is not an Any message,
// when o.Resolver cannot find a message type for the type URL, or when the
// value bytes cannot be unmarshaled into that type. In those cases the caller
// is expected to traverse m like any other message.
//
// Otherwise it reports ok == true. The unmarshaled message is pushed as a
// protopath.AnyExpand step, traversed with rangeMessage, and popped. If the
// deterministic encoding of the expanded message differs before and after the
// traversal, the new bytes are stored into the value field of m. The returned
// err is the traversal error with Break cleared, or a marshal error; a failure
// of the second marshal is returned in place of any traversal error.
func (o Options) rangeAnyMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (ok bool, err error) {
	md := m.Descriptor()
	if md.FullName() != "google.protobuf.Any" {
		return false, nil
	}

	fds := md.Fields()
	url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String()
	val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes()
	mt, errFind := o.Resolver.FindMessageByURL(url)
	if errFind != nil {
		// Unresolvable type: fall back to treating m as a normal message.
		return false, nil
	}

	// Unmarshal the raw encoded message value into a structured message value.
	m2 := mt.New()
	errUnmarshal := proto.UnmarshalOptions{
		Merge:        true,
		AllowPartial: true,
		Resolver:     o.Resolver,
	}.Unmarshal(val, m2.Interface())
	if errUnmarshal != nil {
		// If the underlying message cannot be unmarshaled,
		// then just treat this as a normal message type.
		return false, nil
	}

	// Marshal Any before ranging to detect possible mutations.
	b1, errMarshal := proto.MarshalOptions{
		AllowPartial:  true,
		Deterministic: true,
	}.Marshal(m2.Interface())
	if errMarshal != nil {
		return true, errMarshal
	}

	pushStep(p, protopath.AnyExpand(m2.Descriptor()), protoreflect.ValueOfMessage(m2))
	if push != nil {
		err = amendError(err, push(*p))
	}
	if err == nil {
		err = o.rangeMessage(p, m2, push, pop)
	}
	// pop is called even after an error so that every push is paired.
	if pop != nil {
		err = amendError(err, pop(*p))
	}
	popStep(p)

	// Marshal Any after ranging to detect possible mutations.
	b2, errMarshal := proto.MarshalOptions{
		AllowPartial:  true,
		Deterministic: true,
	}.Marshal(m2.Interface())
	if errMarshal != nil {
		return true, errMarshal
	}

	// Mutations detected, write the new sequence of bytes to the Any message.
	if !bytes.Equal(b1, b2) {
		m.Set(fds.ByNumber(genid.Any_Value_field_number), protoreflect.ValueOfBytes(b2))
	}

	if err == Break {
		err = nil
	}
	return true, err
}

// rangeList visits the elements of ls, the value of list field fd, in index
// order. Each element is pushed as a protopath.ListIndex step, traversed with
// rangeMessage when the field's element kind is a message, and popped.
//
// The loop stops after the first element for which an error (including Break)
// is recorded. A Break is cleared before returning; other errors are returned.
func (o Options) rangeList(p *protopath.Values, fd protoreflect.FieldDescriptor, ls protoreflect.List, push, pop func(protopath.Values) error) (err error) {
	// ls.Len() is re-evaluated on every iteration.
	for i := 0; i < ls.Len() && err == nil; i++ {
		v := ls.Get(i)
		pushStep(p, protopath.ListIndex(i), v)
		if push != nil {
			err = amendError(err, push(*p))
		}
		if err == nil && fd.Message() != nil {
			err = o.rangeMessage(p, v.Message(), push, pop)
		}
		// pop is called even after an error so that every push is paired.
		if pop != nil {
			err = amendError(err, pop(*p))
		}
		popStep(p)
	}

	if err == Break {
		err = nil
	}
	return err
}

// rangeMap visits the entries of ms, the value of map field fd, using
// order.RangeEntries. Each entry is pushed as a protopath.MapIndex step keyed
// by the entry's key, traversed with rangeMessage when the map's value kind is
// a message, and popped. order.GenericKeyOrder is used when o.Stable is set,
// and order.AnyKeyOrder otherwise.
//
// A Break recorded while visiting entries is cleared before returning;
// other errors are returned.
func (o Options) rangeMap(p *protopath.Values, fd protoreflect.FieldDescriptor, ms protoreflect.Map, push, pop func(protopath.Values) error) (err error) {
	keyOrder := order.AnyKeyOrder
	if o.Stable {
		keyOrder = order.GenericKeyOrder
	}
	order.RangeEntries(ms, keyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
		pushStep(p, protopath.MapIndex(k), v)
		if push != nil {
			err = amendError(err, push(*p))
		}
		if err == nil && fd.MapValue().Message() != nil {
			err = o.rangeMessage(p, v.Message(), push, pop)
		}
		// pop is called even after an error so that every push is paired.
		if pop != nil {
			err = amendError(err, pop(*p))
		}
		popStep(p)
		// Any recorded error, Break included, yields false here.
		// NOTE(review): false presumably stops order.RangeEntries from
		// visiting the remaining entries; confirm in internal/order.
		return err == nil
	})

	if err == Break {
		err = nil
	}
	return err
}

// pushStep appends step s and its value v to p, keeping p.Path and p.Values
// the same length.
func pushStep(p *protopath.Values, s protopath.Step, v protoreflect.Value) {
	p.Path = append(p.Path, s)
	p.Values = append(p.Values, v)
}

// popStep removes the last step and the last value from p. It must only be
// called after a matching pushStep; slicing an empty p.Path or p.Values panics.
func popStep(p *protopath.Values) {
	p.Path = p.Path[:len(p.Path)-1]
	p.Values = p.Values[:len(p.Values)-1]
}

// amendError amends the previous error with the current error if it is
// considered more serious. The precedence order for errors is:
//
//	nil < Break < Terminate < previous non-nil < current non-nil
func amendError(prev, curr error) error {
	switch {
	case curr == nil:
		// Nothing new to record.
		return prev
	case curr == Break && prev != nil:
		// Break never replaces an error that is already recorded.
		return prev
	case curr == Terminate && prev != nil && prev != Break:
		// Terminate replaces only nil or Break.
		return prev
	default:
		// curr outranks prev.
		return curr
	}
}