1 | // Copyright 2010 The Go Authors. All rights reserved.
|
---|
2 | // Use of this source code is governed by a BSD-style
|
---|
3 | // license that can be found in the LICENSE file.
|
---|
4 |
|
---|
5 | package proto
|
---|
6 |
|
---|
7 | import (
|
---|
8 | "errors"
|
---|
9 | "fmt"
|
---|
10 | "reflect"
|
---|
11 |
|
---|
12 | "google.golang.org/protobuf/encoding/protowire"
|
---|
13 | "google.golang.org/protobuf/proto"
|
---|
14 | "google.golang.org/protobuf/reflect/protoreflect"
|
---|
15 | "google.golang.org/protobuf/reflect/protoregistry"
|
---|
16 | "google.golang.org/protobuf/runtime/protoiface"
|
---|
17 | "google.golang.org/protobuf/runtime/protoimpl"
|
---|
18 | )
|
---|
19 |
|
---|
20 | type (
|
---|
21 | // ExtensionDesc represents an extension descriptor and
|
---|
22 | // is used to interact with an extension field in a message.
|
---|
23 | //
|
---|
24 | // Variables of this type are generated in code by protoc-gen-go.
|
---|
25 | ExtensionDesc = protoimpl.ExtensionInfo
|
---|
26 |
|
---|
27 | // ExtensionRange represents a range of message extensions.
|
---|
28 | // Used in code generated by protoc-gen-go.
|
---|
29 | ExtensionRange = protoiface.ExtensionRangeV1
|
---|
30 |
|
---|
31 | // Deprecated: Do not use; this is an internal type.
|
---|
32 | Extension = protoimpl.ExtensionFieldV1
|
---|
33 |
|
---|
34 | // Deprecated: Do not use; this is an internal type.
|
---|
35 | XXX_InternalExtensions = protoimpl.ExtensionFields
|
---|
36 | )
|
---|
37 |
|
---|
var (
	// ErrMissingExtension reports that the requested extension field is not
	// present (and, in GetExtension, that it has no default value either).
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is reported when the message is nil or invalid, or
	// (for most operations in this file) declares no extension ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
42 |
|
---|
43 | // HasExtension reports whether the extension field is present in m
|
---|
44 | // either as an explicitly populated field or as an unknown field.
|
---|
45 | func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
|
---|
46 | mr := MessageReflect(m)
|
---|
47 | if mr == nil || !mr.IsValid() {
|
---|
48 | return false
|
---|
49 | }
|
---|
50 |
|
---|
51 | // Check whether any populated known field matches the field number.
|
---|
52 | xtd := xt.TypeDescriptor()
|
---|
53 | if isValidExtension(mr.Descriptor(), xtd) {
|
---|
54 | has = mr.Has(xtd)
|
---|
55 | } else {
|
---|
56 | mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
---|
57 | has = int32(fd.Number()) == xt.Field
|
---|
58 | return !has
|
---|
59 | })
|
---|
60 | }
|
---|
61 |
|
---|
62 | // Check whether any unknown field matches the field number.
|
---|
63 | for b := mr.GetUnknown(); !has && len(b) > 0; {
|
---|
64 | num, _, n := protowire.ConsumeField(b)
|
---|
65 | has = int32(num) == xt.Field
|
---|
66 | b = b[n:]
|
---|
67 | }
|
---|
68 | return has
|
---|
69 | }
|
---|
70 |
|
---|
71 | // ClearExtension removes the extension field from m
|
---|
72 | // either as an explicitly populated field or as an unknown field.
|
---|
73 | func ClearExtension(m Message, xt *ExtensionDesc) {
|
---|
74 | mr := MessageReflect(m)
|
---|
75 | if mr == nil || !mr.IsValid() {
|
---|
76 | return
|
---|
77 | }
|
---|
78 |
|
---|
79 | xtd := xt.TypeDescriptor()
|
---|
80 | if isValidExtension(mr.Descriptor(), xtd) {
|
---|
81 | mr.Clear(xtd)
|
---|
82 | } else {
|
---|
83 | mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
---|
84 | if int32(fd.Number()) == xt.Field {
|
---|
85 | mr.Clear(fd)
|
---|
86 | return false
|
---|
87 | }
|
---|
88 | return true
|
---|
89 | })
|
---|
90 | }
|
---|
91 | clearUnknown(mr, fieldNum(xt.Field))
|
---|
92 | }
|
---|
93 |
|
---|
94 | // ClearAllExtensions clears all extensions from m.
|
---|
95 | // This includes populated fields and unknown fields in the extension range.
|
---|
96 | func ClearAllExtensions(m Message) {
|
---|
97 | mr := MessageReflect(m)
|
---|
98 | if mr == nil || !mr.IsValid() {
|
---|
99 | return
|
---|
100 | }
|
---|
101 |
|
---|
102 | mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
---|
103 | if fd.IsExtension() {
|
---|
104 | mr.Clear(fd)
|
---|
105 | }
|
---|
106 | return true
|
---|
107 | })
|
---|
108 | clearUnknown(mr, mr.Descriptor().ExtensionRanges())
|
---|
109 | }
|
---|
110 |
|
---|
111 | // GetExtension retrieves a proto2 extended field from m.
|
---|
112 | //
|
---|
113 | // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
|
---|
114 | // then GetExtension parses the encoded field and returns a Go value of the specified type.
|
---|
115 | // If the field is not present, then the default value is returned (if one is specified),
|
---|
116 | // otherwise ErrMissingExtension is reported.
|
---|
117 | //
|
---|
118 | // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
|
---|
119 | // then GetExtension returns the raw encoded bytes for the extension field.
|
---|
120 | func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
|
---|
121 | mr := MessageReflect(m)
|
---|
122 | if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
|
---|
123 | return nil, errNotExtendable
|
---|
124 | }
|
---|
125 |
|
---|
126 | // Retrieve the unknown fields for this extension field.
|
---|
127 | var bo protoreflect.RawFields
|
---|
128 | for bi := mr.GetUnknown(); len(bi) > 0; {
|
---|
129 | num, _, n := protowire.ConsumeField(bi)
|
---|
130 | if int32(num) == xt.Field {
|
---|
131 | bo = append(bo, bi[:n]...)
|
---|
132 | }
|
---|
133 | bi = bi[n:]
|
---|
134 | }
|
---|
135 |
|
---|
136 | // For type incomplete descriptors, only retrieve the unknown fields.
|
---|
137 | if xt.ExtensionType == nil {
|
---|
138 | return []byte(bo), nil
|
---|
139 | }
|
---|
140 |
|
---|
141 | // If the extension field only exists as unknown fields, unmarshal it.
|
---|
142 | // This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
|
---|
143 | xtd := xt.TypeDescriptor()
|
---|
144 | if !isValidExtension(mr.Descriptor(), xtd) {
|
---|
145 | return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
|
---|
146 | }
|
---|
147 | if !mr.Has(xtd) && len(bo) > 0 {
|
---|
148 | m2 := mr.New()
|
---|
149 | if err := (proto.UnmarshalOptions{
|
---|
150 | Resolver: extensionResolver{xt},
|
---|
151 | }.Unmarshal(bo, m2.Interface())); err != nil {
|
---|
152 | return nil, err
|
---|
153 | }
|
---|
154 | if m2.Has(xtd) {
|
---|
155 | mr.Set(xtd, m2.Get(xtd))
|
---|
156 | clearUnknown(mr, fieldNum(xt.Field))
|
---|
157 | }
|
---|
158 | }
|
---|
159 |
|
---|
160 | // Check whether the message has the extension field set or a default.
|
---|
161 | var pv protoreflect.Value
|
---|
162 | switch {
|
---|
163 | case mr.Has(xtd):
|
---|
164 | pv = mr.Get(xtd)
|
---|
165 | case xtd.HasDefault():
|
---|
166 | pv = xtd.Default()
|
---|
167 | default:
|
---|
168 | return nil, ErrMissingExtension
|
---|
169 | }
|
---|
170 |
|
---|
171 | v := xt.InterfaceOf(pv)
|
---|
172 | rv := reflect.ValueOf(v)
|
---|
173 | if isScalarKind(rv.Kind()) {
|
---|
174 | rv2 := reflect.New(rv.Type())
|
---|
175 | rv2.Elem().Set(rv)
|
---|
176 | v = rv2.Interface()
|
---|
177 | }
|
---|
178 | return v, nil
|
---|
179 | }
|
---|
180 |
|
---|
181 | // extensionResolver is a custom extension resolver that stores a single
|
---|
182 | // extension type that takes precedence over the global registry.
|
---|
183 | type extensionResolver struct{ xt protoreflect.ExtensionType }
|
---|
184 |
|
---|
185 | func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
|
---|
186 | if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
|
---|
187 | return r.xt, nil
|
---|
188 | }
|
---|
189 | return protoregistry.GlobalTypes.FindExtensionByName(field)
|
---|
190 | }
|
---|
191 |
|
---|
192 | func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
|
---|
193 | if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
|
---|
194 | return r.xt, nil
|
---|
195 | }
|
---|
196 | return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
|
---|
197 | }
|
---|
198 |
|
---|
199 | // GetExtensions returns a list of the extensions values present in m,
|
---|
200 | // corresponding with the provided list of extension descriptors, xts.
|
---|
201 | // If an extension is missing in m, the corresponding value is nil.
|
---|
202 | func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
|
---|
203 | mr := MessageReflect(m)
|
---|
204 | if mr == nil || !mr.IsValid() {
|
---|
205 | return nil, errNotExtendable
|
---|
206 | }
|
---|
207 |
|
---|
208 | vs := make([]interface{}, len(xts))
|
---|
209 | for i, xt := range xts {
|
---|
210 | v, err := GetExtension(m, xt)
|
---|
211 | if err != nil {
|
---|
212 | if err == ErrMissingExtension {
|
---|
213 | continue
|
---|
214 | }
|
---|
215 | return vs, err
|
---|
216 | }
|
---|
217 | vs[i] = v
|
---|
218 | }
|
---|
219 | return vs, nil
|
---|
220 | }
|
---|
221 |
|
---|
222 | // SetExtension sets an extension field in m to the provided value.
|
---|
223 | func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
|
---|
224 | mr := MessageReflect(m)
|
---|
225 | if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
|
---|
226 | return errNotExtendable
|
---|
227 | }
|
---|
228 |
|
---|
229 | rv := reflect.ValueOf(v)
|
---|
230 | if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
|
---|
231 | return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
|
---|
232 | }
|
---|
233 | if rv.Kind() == reflect.Ptr {
|
---|
234 | if rv.IsNil() {
|
---|
235 | return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
|
---|
236 | }
|
---|
237 | if isScalarKind(rv.Elem().Kind()) {
|
---|
238 | v = rv.Elem().Interface()
|
---|
239 | }
|
---|
240 | }
|
---|
241 |
|
---|
242 | xtd := xt.TypeDescriptor()
|
---|
243 | if !isValidExtension(mr.Descriptor(), xtd) {
|
---|
244 | return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
|
---|
245 | }
|
---|
246 | mr.Set(xtd, xt.ValueOf(v))
|
---|
247 | clearUnknown(mr, fieldNum(xt.Field))
|
---|
248 | return nil
|
---|
249 | }
|
---|
250 |
|
---|
251 | // SetRawExtension inserts b into the unknown fields of m.
|
---|
252 | //
|
---|
253 | // Deprecated: Use Message.ProtoReflect.SetUnknown instead.
|
---|
254 | func SetRawExtension(m Message, fnum int32, b []byte) {
|
---|
255 | mr := MessageReflect(m)
|
---|
256 | if mr == nil || !mr.IsValid() {
|
---|
257 | return
|
---|
258 | }
|
---|
259 |
|
---|
260 | // Verify that the raw field is valid.
|
---|
261 | for b0 := b; len(b0) > 0; {
|
---|
262 | num, _, n := protowire.ConsumeField(b0)
|
---|
263 | if int32(num) != fnum {
|
---|
264 | panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
|
---|
265 | }
|
---|
266 | b0 = b0[n:]
|
---|
267 | }
|
---|
268 |
|
---|
269 | ClearExtension(m, &ExtensionDesc{Field: fnum})
|
---|
270 | mr.SetUnknown(append(mr.GetUnknown(), b...))
|
---|
271 | }
|
---|
272 |
|
---|
273 | // ExtensionDescs returns a list of extension descriptors found in m,
|
---|
274 | // containing descriptors for both populated extension fields in m and
|
---|
275 | // also unknown fields of m that are in the extension range.
|
---|
276 | // For the later case, an type incomplete descriptor is provided where only
|
---|
277 | // the ExtensionDesc.Field field is populated.
|
---|
278 | // The order of the extension descriptors is undefined.
|
---|
279 | func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
|
---|
280 | mr := MessageReflect(m)
|
---|
281 | if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
|
---|
282 | return nil, errNotExtendable
|
---|
283 | }
|
---|
284 |
|
---|
285 | // Collect a set of known extension descriptors.
|
---|
286 | extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
|
---|
287 | mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
---|
288 | if fd.IsExtension() {
|
---|
289 | xt := fd.(protoreflect.ExtensionTypeDescriptor)
|
---|
290 | if xd, ok := xt.Type().(*ExtensionDesc); ok {
|
---|
291 | extDescs[fd.Number()] = xd
|
---|
292 | }
|
---|
293 | }
|
---|
294 | return true
|
---|
295 | })
|
---|
296 |
|
---|
297 | // Collect a set of unknown extension descriptors.
|
---|
298 | extRanges := mr.Descriptor().ExtensionRanges()
|
---|
299 | for b := mr.GetUnknown(); len(b) > 0; {
|
---|
300 | num, _, n := protowire.ConsumeField(b)
|
---|
301 | if extRanges.Has(num) && extDescs[num] == nil {
|
---|
302 | extDescs[num] = nil
|
---|
303 | }
|
---|
304 | b = b[n:]
|
---|
305 | }
|
---|
306 |
|
---|
307 | // Transpose the set of descriptors into a list.
|
---|
308 | var xts []*ExtensionDesc
|
---|
309 | for num, xt := range extDescs {
|
---|
310 | if xt == nil {
|
---|
311 | xt = &ExtensionDesc{Field: int32(num)}
|
---|
312 | }
|
---|
313 | xts = append(xts, xt)
|
---|
314 | }
|
---|
315 | return xts, nil
|
---|
316 | }
|
---|
317 |
|
---|
318 | // isValidExtension reports whether xtd is a valid extension descriptor for md.
|
---|
319 | func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
|
---|
320 | return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
|
---|
321 | }
|
---|
322 |
|
---|
// isScalarKind reports whether k is the reflect.Kind of a protobuf scalar
// other than bytes. It exists for historical reasons: v1 represents such
// scalars as *T while v2 uses T, so callers use it to decide when a pointer
// must be added or stripped.
func isScalarKind(k reflect.Kind) bool {
	return k == reflect.Bool ||
		k == reflect.Int32 || k == reflect.Int64 ||
		k == reflect.Uint32 || k == reflect.Uint64 ||
		k == reflect.Float32 || k == reflect.Float64 ||
		k == reflect.String
}
334 |
|
---|
335 | // clearUnknown removes unknown fields from m where remover.Has reports true.
|
---|
336 | func clearUnknown(m protoreflect.Message, remover interface {
|
---|
337 | Has(protoreflect.FieldNumber) bool
|
---|
338 | }) {
|
---|
339 | var bo protoreflect.RawFields
|
---|
340 | for bi := m.GetUnknown(); len(bi) > 0; {
|
---|
341 | num, _, n := protowire.ConsumeField(bi)
|
---|
342 | if !remover.Has(num) {
|
---|
343 | bo = append(bo, bi[:n]...)
|
---|
344 | }
|
---|
345 | bi = bi[n:]
|
---|
346 | }
|
---|
347 | if bi := m.GetUnknown(); len(bi) != len(bo) {
|
---|
348 | m.SetUnknown(bo)
|
---|
349 | }
|
---|
350 | }
|
---|
351 |
|
---|
352 | type fieldNum protoreflect.FieldNumber
|
---|
353 |
|
---|
354 | func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
|
---|
355 | return protoreflect.FieldNumber(n1) == n2
|
---|
356 | }
|
---|