1 | // Copyright 2019 The Go Authors. All rights reserved.
|
---|
2 | // Use of this source code is governed by a BSD-style
|
---|
3 | // license that can be found in the LICENSE file.
|
---|
4 |
|
---|
5 | package proto
|
---|
6 |
|
---|
7 | import (
|
---|
8 | "google.golang.org/protobuf/internal/errors"
|
---|
9 | "google.golang.org/protobuf/reflect/protoreflect"
|
---|
10 | "google.golang.org/protobuf/runtime/protoiface"
|
---|
11 | )
|
---|
12 |
|
---|
13 | // CheckInitialized returns an error if any required fields in m are not set.
|
---|
14 | func CheckInitialized(m Message) error {
|
---|
15 | // Treat a nil message interface as an "untyped" empty message,
|
---|
16 | // which we assume to have no required fields.
|
---|
17 | if m == nil {
|
---|
18 | return nil
|
---|
19 | }
|
---|
20 |
|
---|
21 | return checkInitialized(m.ProtoReflect())
|
---|
22 | }
|
---|
23 |
|
---|
24 | // CheckInitialized returns an error if any required fields in m are not set.
|
---|
25 | func checkInitialized(m protoreflect.Message) error {
|
---|
26 | if methods := protoMethods(m); methods != nil && methods.CheckInitialized != nil {
|
---|
27 | _, err := methods.CheckInitialized(protoiface.CheckInitializedInput{
|
---|
28 | Message: m,
|
---|
29 | })
|
---|
30 | return err
|
---|
31 | }
|
---|
32 | return checkInitializedSlow(m)
|
---|
33 | }
|
---|
34 |
|
---|
35 | func checkInitializedSlow(m protoreflect.Message) error {
|
---|
36 | md := m.Descriptor()
|
---|
37 | fds := md.Fields()
|
---|
38 | for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
|
---|
39 | fd := fds.ByNumber(nums.Get(i))
|
---|
40 | if !m.Has(fd) {
|
---|
41 | return errors.RequiredNotSet(string(fd.FullName()))
|
---|
42 | }
|
---|
43 | }
|
---|
44 | var err error
|
---|
45 | m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
---|
46 | switch {
|
---|
47 | case fd.IsList():
|
---|
48 | if fd.Message() == nil {
|
---|
49 | return true
|
---|
50 | }
|
---|
51 | for i, list := 0, v.List(); i < list.Len() && err == nil; i++ {
|
---|
52 | err = checkInitialized(list.Get(i).Message())
|
---|
53 | }
|
---|
54 | case fd.IsMap():
|
---|
55 | if fd.MapValue().Message() == nil {
|
---|
56 | return true
|
---|
57 | }
|
---|
58 | v.Map().Range(func(key protoreflect.MapKey, v protoreflect.Value) bool {
|
---|
59 | err = checkInitialized(v.Message())
|
---|
60 | return err == nil
|
---|
61 | })
|
---|
62 | default:
|
---|
63 | if fd.Message() == nil {
|
---|
64 | return true
|
---|
65 | }
|
---|
66 | err = checkInitialized(v.Message())
|
---|
67 | }
|
---|
68 | return err == nil
|
---|
69 | })
|
---|
70 | return err
|
---|
71 | }
|
---|