1 | // Copyright 2019 The Go Authors. All rights reserved.
|
---|
2 | // Use of this source code is governed by a BSD-style
|
---|
3 | // license that can be found in the LICENSE file.
|
---|
4 |
|
---|
5 | package impl
|
---|
6 |
|
---|
7 | import (
|
---|
8 | "sync"
|
---|
9 |
|
---|
10 | "google.golang.org/protobuf/internal/errors"
|
---|
11 | "google.golang.org/protobuf/reflect/protoreflect"
|
---|
12 | "google.golang.org/protobuf/runtime/protoiface"
|
---|
13 | )
|
---|
14 |
|
---|
15 | func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) {
|
---|
16 | var p pointer
|
---|
17 | if ms, ok := in.Message.(*messageState); ok {
|
---|
18 | p = ms.pointer()
|
---|
19 | } else {
|
---|
20 | p = in.Message.(*messageReflectWrapper).pointer()
|
---|
21 | }
|
---|
22 | return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
|
---|
23 | }
|
---|
24 |
|
---|
25 | func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
|
---|
26 | mi.init()
|
---|
27 | if !mi.needsInitCheck {
|
---|
28 | return nil
|
---|
29 | }
|
---|
30 | if p.IsNil() {
|
---|
31 | for _, f := range mi.orderedCoderFields {
|
---|
32 | if f.isRequired {
|
---|
33 | return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
|
---|
34 | }
|
---|
35 | }
|
---|
36 | return nil
|
---|
37 | }
|
---|
38 | if mi.extensionOffset.IsValid() {
|
---|
39 | e := p.Apply(mi.extensionOffset).Extensions()
|
---|
40 | if err := mi.isInitExtensions(e); err != nil {
|
---|
41 | return err
|
---|
42 | }
|
---|
43 | }
|
---|
44 | for _, f := range mi.orderedCoderFields {
|
---|
45 | if !f.isRequired && f.funcs.isInit == nil {
|
---|
46 | continue
|
---|
47 | }
|
---|
48 | fptr := p.Apply(f.offset)
|
---|
49 | if f.isPointer && fptr.Elem().IsNil() {
|
---|
50 | if f.isRequired {
|
---|
51 | return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
|
---|
52 | }
|
---|
53 | continue
|
---|
54 | }
|
---|
55 | if f.funcs.isInit == nil {
|
---|
56 | continue
|
---|
57 | }
|
---|
58 | if err := f.funcs.isInit(fptr, f); err != nil {
|
---|
59 | return err
|
---|
60 | }
|
---|
61 | }
|
---|
62 | return nil
|
---|
63 | }
|
---|
64 |
|
---|
65 | func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
|
---|
66 | if ext == nil {
|
---|
67 | return nil
|
---|
68 | }
|
---|
69 | for _, x := range *ext {
|
---|
70 | ei := getExtensionFieldInfo(x.Type())
|
---|
71 | if ei.funcs.isInit == nil {
|
---|
72 | continue
|
---|
73 | }
|
---|
74 | v := x.Value()
|
---|
75 | if !v.IsValid() {
|
---|
76 | continue
|
---|
77 | }
|
---|
78 | if err := ei.funcs.isInit(v); err != nil {
|
---|
79 | return err
|
---|
80 | }
|
---|
81 | }
|
---|
82 | return nil
|
---|
83 | }
|
---|
84 |
|
---|
var (
	// needsInitCheckMu serializes computation of needsInitCheck results.
	needsInitCheckMu sync.Mutex

	// needsInitCheckMap caches needsInitCheck results keyed by
	// protoreflect.MessageDescriptor. Only bool values are final answers;
	// readers must treat any other stored value as absent.
	needsInitCheckMap sync.Map
)
|
---|
89 |
|
---|
90 | // needsInitCheck reports whether a message needs to be checked for partial initialization.
|
---|
91 | //
|
---|
92 | // It returns true if the message transitively includes any required or extension fields.
|
---|
93 | func needsInitCheck(md protoreflect.MessageDescriptor) bool {
|
---|
94 | if v, ok := needsInitCheckMap.Load(md); ok {
|
---|
95 | if has, ok := v.(bool); ok {
|
---|
96 | return has
|
---|
97 | }
|
---|
98 | }
|
---|
99 | needsInitCheckMu.Lock()
|
---|
100 | defer needsInitCheckMu.Unlock()
|
---|
101 | return needsInitCheckLocked(md)
|
---|
102 | }
|
---|
103 |
|
---|
104 | func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) {
|
---|
105 | if v, ok := needsInitCheckMap.Load(md); ok {
|
---|
106 | // If has is true, we've previously determined that this message
|
---|
107 | // needs init checks.
|
---|
108 | //
|
---|
109 | // If has is false, we've previously determined that it can never
|
---|
110 | // be uninitialized.
|
---|
111 | //
|
---|
112 | // If has is not a bool, we've just encountered a cycle in the
|
---|
113 | // message graph. In this case, it is safe to return false: If
|
---|
114 | // the message does have required fields, we'll detect them later
|
---|
115 | // in the graph traversal.
|
---|
116 | has, ok := v.(bool)
|
---|
117 | return ok && has
|
---|
118 | }
|
---|
119 | needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
|
---|
120 | defer func() {
|
---|
121 | needsInitCheckMap.Store(md, has)
|
---|
122 | }()
|
---|
123 | if md.RequiredNumbers().Len() > 0 {
|
---|
124 | return true
|
---|
125 | }
|
---|
126 | if md.ExtensionRanges().Len() > 0 {
|
---|
127 | return true
|
---|
128 | }
|
---|
129 | for i := 0; i < md.Fields().Len(); i++ {
|
---|
130 | fd := md.Fields().Get(i)
|
---|
131 | // Map keys are never messages, so just consider the map value.
|
---|
132 | if fd.IsMap() {
|
---|
133 | fd = fd.MapValue()
|
---|
134 | }
|
---|
135 | fmd := fd.Message()
|
---|
136 | if fmd != nil && needsInitCheckLocked(fmd) {
|
---|
137 | return true
|
---|
138 | }
|
---|
139 | }
|
---|
140 | return false
|
---|
141 | }
|
---|