1 | // Copyright 2018 The Go Authors. All rights reserved.
|
---|
2 | // Use of this source code is governed by a BSD-style
|
---|
3 | // license that can be found in the LICENSE file.
|
---|
4 |
|
---|
5 | package proto
|
---|
6 |
|
---|
7 | import (
|
---|
8 | "google.golang.org/protobuf/encoding/protowire"
|
---|
9 | "google.golang.org/protobuf/internal/encoding/messageset"
|
---|
10 | "google.golang.org/protobuf/internal/errors"
|
---|
11 | "google.golang.org/protobuf/internal/flags"
|
---|
12 | "google.golang.org/protobuf/internal/genid"
|
---|
13 | "google.golang.org/protobuf/internal/pragma"
|
---|
14 | "google.golang.org/protobuf/reflect/protoreflect"
|
---|
15 | "google.golang.org/protobuf/reflect/protoregistry"
|
---|
16 | "google.golang.org/protobuf/runtime/protoiface"
|
---|
17 | )
|
---|
18 |
|
---|
19 | // UnmarshalOptions configures the unmarshaler.
|
---|
20 | //
|
---|
21 | // Example usage:
|
---|
22 | //
|
---|
23 | // err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
|
---|
24 | type UnmarshalOptions struct {
|
---|
25 | pragma.NoUnkeyedLiterals
|
---|
26 |
|
---|
27 | // Merge merges the input into the destination message.
|
---|
28 | // The default behavior is to always reset the message before unmarshaling,
|
---|
29 | // unless Merge is specified.
|
---|
30 | Merge bool
|
---|
31 |
|
---|
32 | // AllowPartial accepts input for messages that will result in missing
|
---|
33 | // required fields. If AllowPartial is false (the default), Unmarshal will
|
---|
34 | // return an error if there are any missing required fields.
|
---|
35 | AllowPartial bool
|
---|
36 |
|
---|
37 | // If DiscardUnknown is set, unknown fields are ignored.
|
---|
38 | DiscardUnknown bool
|
---|
39 |
|
---|
40 | // Resolver is used for looking up types when unmarshaling extension fields.
|
---|
41 | // If nil, this defaults to using protoregistry.GlobalTypes.
|
---|
42 | Resolver interface {
|
---|
43 | FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
|
---|
44 | FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
|
---|
45 | }
|
---|
46 |
|
---|
47 | // RecursionLimit limits how deeply messages may be nested.
|
---|
48 | // If zero, a default limit is applied.
|
---|
49 | RecursionLimit int
|
---|
50 | }
|
---|
51 |
|
---|
52 | // Unmarshal parses the wire-format message in b and places the result in m.
|
---|
53 | // The provided message must be mutable (e.g., a non-nil pointer to a message).
|
---|
54 | func Unmarshal(b []byte, m Message) error {
|
---|
55 | _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
|
---|
56 | return err
|
---|
57 | }
|
---|
58 |
|
---|
59 | // Unmarshal parses the wire-format message in b and places the result in m.
|
---|
60 | // The provided message must be mutable (e.g., a non-nil pointer to a message).
|
---|
61 | func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
|
---|
62 | if o.RecursionLimit == 0 {
|
---|
63 | o.RecursionLimit = protowire.DefaultRecursionLimit
|
---|
64 | }
|
---|
65 | _, err := o.unmarshal(b, m.ProtoReflect())
|
---|
66 | return err
|
---|
67 | }
|
---|
68 |
|
---|
69 | // UnmarshalState parses a wire-format message and places the result in m.
|
---|
70 | //
|
---|
71 | // This method permits fine-grained control over the unmarshaler.
|
---|
72 | // Most users should use Unmarshal instead.
|
---|
73 | func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
|
---|
74 | if o.RecursionLimit == 0 {
|
---|
75 | o.RecursionLimit = protowire.DefaultRecursionLimit
|
---|
76 | }
|
---|
77 | return o.unmarshal(in.Buf, in.Message)
|
---|
78 | }
|
---|
79 |
|
---|
80 | // unmarshal is a centralized function that all unmarshal operations go through.
|
---|
81 | // For profiling purposes, avoid changing the name of this function or
|
---|
82 | // introducing other code paths for unmarshal that do not go through this.
|
---|
83 | func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
|
---|
84 | if o.Resolver == nil {
|
---|
85 | o.Resolver = protoregistry.GlobalTypes
|
---|
86 | }
|
---|
87 | if !o.Merge {
|
---|
88 | Reset(m.Interface())
|
---|
89 | }
|
---|
90 | allowPartial := o.AllowPartial
|
---|
91 | o.Merge = true
|
---|
92 | o.AllowPartial = true
|
---|
93 | methods := protoMethods(m)
|
---|
94 | if methods != nil && methods.Unmarshal != nil &&
|
---|
95 | !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
|
---|
96 | in := protoiface.UnmarshalInput{
|
---|
97 | Message: m,
|
---|
98 | Buf: b,
|
---|
99 | Resolver: o.Resolver,
|
---|
100 | Depth: o.RecursionLimit,
|
---|
101 | }
|
---|
102 | if o.DiscardUnknown {
|
---|
103 | in.Flags |= protoiface.UnmarshalDiscardUnknown
|
---|
104 | }
|
---|
105 | out, err = methods.Unmarshal(in)
|
---|
106 | } else {
|
---|
107 | o.RecursionLimit--
|
---|
108 | if o.RecursionLimit < 0 {
|
---|
109 | return out, errors.New("exceeded max recursion depth")
|
---|
110 | }
|
---|
111 | err = o.unmarshalMessageSlow(b, m)
|
---|
112 | }
|
---|
113 | if err != nil {
|
---|
114 | return out, err
|
---|
115 | }
|
---|
116 | if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
|
---|
117 | return out, nil
|
---|
118 | }
|
---|
119 | return out, checkInitialized(m)
|
---|
120 | }
|
---|
121 |
|
---|
122 | func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
|
---|
123 | _, err := o.unmarshal(b, m)
|
---|
124 | return err
|
---|
125 | }
|
---|
126 |
|
---|
127 | func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
|
---|
128 | md := m.Descriptor()
|
---|
129 | if messageset.IsMessageSet(md) {
|
---|
130 | return o.unmarshalMessageSet(b, m)
|
---|
131 | }
|
---|
132 | fields := md.Fields()
|
---|
133 | for len(b) > 0 {
|
---|
134 | // Parse the tag (field number and wire type).
|
---|
135 | num, wtyp, tagLen := protowire.ConsumeTag(b)
|
---|
136 | if tagLen < 0 {
|
---|
137 | return errDecode
|
---|
138 | }
|
---|
139 | if num > protowire.MaxValidNumber {
|
---|
140 | return errDecode
|
---|
141 | }
|
---|
142 |
|
---|
143 | // Find the field descriptor for this field number.
|
---|
144 | fd := fields.ByNumber(num)
|
---|
145 | if fd == nil && md.ExtensionRanges().Has(num) {
|
---|
146 | extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
|
---|
147 | if err != nil && err != protoregistry.NotFound {
|
---|
148 | return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
|
---|
149 | }
|
---|
150 | if extType != nil {
|
---|
151 | fd = extType.TypeDescriptor()
|
---|
152 | }
|
---|
153 | }
|
---|
154 | var err error
|
---|
155 | if fd == nil {
|
---|
156 | err = errUnknown
|
---|
157 | } else if flags.ProtoLegacy {
|
---|
158 | if fd.IsWeak() && fd.Message().IsPlaceholder() {
|
---|
159 | err = errUnknown // weak referent is not linked in
|
---|
160 | }
|
---|
161 | }
|
---|
162 |
|
---|
163 | // Parse the field value.
|
---|
164 | var valLen int
|
---|
165 | switch {
|
---|
166 | case err != nil:
|
---|
167 | case fd.IsList():
|
---|
168 | valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
|
---|
169 | case fd.IsMap():
|
---|
170 | valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
|
---|
171 | default:
|
---|
172 | valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
|
---|
173 | }
|
---|
174 | if err != nil {
|
---|
175 | if err != errUnknown {
|
---|
176 | return err
|
---|
177 | }
|
---|
178 | valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
|
---|
179 | if valLen < 0 {
|
---|
180 | return errDecode
|
---|
181 | }
|
---|
182 | if !o.DiscardUnknown {
|
---|
183 | m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
|
---|
184 | }
|
---|
185 | }
|
---|
186 | b = b[tagLen+valLen:]
|
---|
187 | }
|
---|
188 | return nil
|
---|
189 | }
|
---|
190 |
|
---|
191 | func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
|
---|
192 | v, n, err := o.unmarshalScalar(b, wtyp, fd)
|
---|
193 | if err != nil {
|
---|
194 | return 0, err
|
---|
195 | }
|
---|
196 | switch fd.Kind() {
|
---|
197 | case protoreflect.GroupKind, protoreflect.MessageKind:
|
---|
198 | m2 := m.Mutable(fd).Message()
|
---|
199 | if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
|
---|
200 | return n, err
|
---|
201 | }
|
---|
202 | default:
|
---|
203 | // Non-message scalars replace the previous value.
|
---|
204 | m.Set(fd, v)
|
---|
205 | }
|
---|
206 | return n, nil
|
---|
207 | }
|
---|
208 |
|
---|
209 | func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
|
---|
210 | if wtyp != protowire.BytesType {
|
---|
211 | return 0, errUnknown
|
---|
212 | }
|
---|
213 | b, n = protowire.ConsumeBytes(b)
|
---|
214 | if n < 0 {
|
---|
215 | return 0, errDecode
|
---|
216 | }
|
---|
217 | var (
|
---|
218 | keyField = fd.MapKey()
|
---|
219 | valField = fd.MapValue()
|
---|
220 | key protoreflect.Value
|
---|
221 | val protoreflect.Value
|
---|
222 | haveKey bool
|
---|
223 | haveVal bool
|
---|
224 | )
|
---|
225 | switch valField.Kind() {
|
---|
226 | case protoreflect.GroupKind, protoreflect.MessageKind:
|
---|
227 | val = mapv.NewValue()
|
---|
228 | }
|
---|
229 | // Map entries are represented as a two-element message with fields
|
---|
230 | // containing the key and value.
|
---|
231 | for len(b) > 0 {
|
---|
232 | num, wtyp, n := protowire.ConsumeTag(b)
|
---|
233 | if n < 0 {
|
---|
234 | return 0, errDecode
|
---|
235 | }
|
---|
236 | if num > protowire.MaxValidNumber {
|
---|
237 | return 0, errDecode
|
---|
238 | }
|
---|
239 | b = b[n:]
|
---|
240 | err = errUnknown
|
---|
241 | switch num {
|
---|
242 | case genid.MapEntry_Key_field_number:
|
---|
243 | key, n, err = o.unmarshalScalar(b, wtyp, keyField)
|
---|
244 | if err != nil {
|
---|
245 | break
|
---|
246 | }
|
---|
247 | haveKey = true
|
---|
248 | case genid.MapEntry_Value_field_number:
|
---|
249 | var v protoreflect.Value
|
---|
250 | v, n, err = o.unmarshalScalar(b, wtyp, valField)
|
---|
251 | if err != nil {
|
---|
252 | break
|
---|
253 | }
|
---|
254 | switch valField.Kind() {
|
---|
255 | case protoreflect.GroupKind, protoreflect.MessageKind:
|
---|
256 | if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
|
---|
257 | return 0, err
|
---|
258 | }
|
---|
259 | default:
|
---|
260 | val = v
|
---|
261 | }
|
---|
262 | haveVal = true
|
---|
263 | }
|
---|
264 | if err == errUnknown {
|
---|
265 | n = protowire.ConsumeFieldValue(num, wtyp, b)
|
---|
266 | if n < 0 {
|
---|
267 | return 0, errDecode
|
---|
268 | }
|
---|
269 | } else if err != nil {
|
---|
270 | return 0, err
|
---|
271 | }
|
---|
272 | b = b[n:]
|
---|
273 | }
|
---|
274 | // Every map entry should have entries for key and value, but this is not strictly required.
|
---|
275 | if !haveKey {
|
---|
276 | key = keyField.Default()
|
---|
277 | }
|
---|
278 | if !haveVal {
|
---|
279 | switch valField.Kind() {
|
---|
280 | case protoreflect.GroupKind, protoreflect.MessageKind:
|
---|
281 | default:
|
---|
282 | val = valField.Default()
|
---|
283 | }
|
---|
284 | }
|
---|
285 | mapv.Set(key.MapKey(), val)
|
---|
286 | return n, nil
|
---|
287 | }
|
---|
288 |
|
---|
var (
	// errUnknown is an internal sentinel. Decoders return it to mean "this
	// field belongs in the message's unknown-field set". It is never returned
	// from an exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")

	// errDecode reports input that is not valid protobuf wire format.
	errDecode = errors.New("cannot parse invalid wire-format data")
)
|
---|