source: code/trunk/vendor/git.sr.ht/~sircmpwn/go-bare/unmarshal.go@ 822

Last change on this file since 822 was 822, checked in by yakumo.izuru, 22 months ago

Prefer immortal.run over runit and rc.d, use vendored modules
for convenience.

Signed-off-by: Izuru Yakumo <yakumo.izuru@…>

File size: 7.0 KB
Line 
1package bare
2
3import (
4 "bytes"
5 "errors"
6 "fmt"
7 "io"
8 "reflect"
9 "sync"
10)
11
12// A type which implements this interface will be responsible for unmarshaling
13// itself when encountered.
14type Unmarshalable interface {
15 Unmarshal(r *Reader) error
16}
17
18// Unmarshals a BARE message into val, which must be a pointer to a value of
19// the message type.
20func Unmarshal(data []byte, val interface{}) error {
21 b := bytes.NewReader(data)
22 r := NewReader(b)
23 return UnmarshalBareReader(r, val)
24}
25
26// Unmarshals a BARE message into value (val, which must be a pointer), from a
27// reader. See Unmarshal for details.
28func UnmarshalReader(r io.Reader, val interface{}) error {
29 r = newLimitedReader(r)
30 return UnmarshalBareReader(NewReader(r), val)
31}
32
33type decodeFunc func(r *Reader, v reflect.Value) error
34
35var decodeFuncCache sync.Map // map[reflect.Type]decodeFunc
36
37func UnmarshalBareReader(r *Reader, val interface{}) error {
38 t := reflect.TypeOf(val)
39 v := reflect.ValueOf(val)
40 if t.Kind() != reflect.Ptr {
41 return errors.New("Expected val to be pointer type")
42 }
43
44 return getDecoder(t.Elem())(r, v.Elem())
45}
46
47// get decoder from cache
48func getDecoder(t reflect.Type) decodeFunc {
49 if f, ok := decodeFuncCache.Load(t); ok {
50 return f.(decodeFunc)
51 }
52
53 f := decoderFunc(t)
54 decodeFuncCache.Store(t, f)
55 return f
56}
57
58var unmarshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()
59
60func decoderFunc(t reflect.Type) decodeFunc {
61 if reflect.PtrTo(t).Implements(unmarshalableInterface) {
62 return func(r *Reader, v reflect.Value) error {
63 uv := v.Addr().Interface().(Unmarshalable)
64 return uv.Unmarshal(r)
65 }
66 }
67
68 if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
69 return decodeUnion(t)
70 }
71
72 switch t.Kind() {
73 case reflect.Ptr:
74 return decodeOptional(t.Elem())
75 case reflect.Struct:
76 return decodeStruct(t)
77 case reflect.Array:
78 return decodeArray(t)
79 case reflect.Slice:
80 return decodeSlice(t)
81 case reflect.Map:
82 return decodeMap(t)
83 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
84 return decodeUint
85 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
86 return decodeInt
87 case reflect.Float32, reflect.Float64:
88 return decodeFloat
89 case reflect.Bool:
90 return decodeBool
91 case reflect.String:
92 return decodeString
93 }
94
95 return func(r *Reader, v reflect.Value) error {
96 return &UnsupportedTypeError{v.Type()}
97 }
98}
99
100func decodeOptional(t reflect.Type) decodeFunc {
101 return func(r *Reader, v reflect.Value) error {
102 s, err := r.ReadU8()
103 if err != nil {
104 return err
105 }
106
107 if s > 1 {
108 return fmt.Errorf("Invalid optional value: %#x", s)
109 }
110
111 if s == 0 {
112 return nil
113 }
114
115 v.Set(reflect.New(t))
116 return getDecoder(t)(r, v.Elem())
117 }
118}
119
120func decodeStruct(t reflect.Type) decodeFunc {
121 n := t.NumField()
122 decoders := make([]decodeFunc, n)
123 for i := 0; i < n; i++ {
124 field := t.Field(i)
125 if field.Tag.Get("bare") == "-" {
126 continue
127 }
128 decoders[i] = getDecoder(field.Type)
129 }
130
131 return func(r *Reader, v reflect.Value) error {
132 for i := 0; i < n; i++ {
133 if decoders[i] == nil {
134 continue
135 }
136 err := decoders[i](r, v.Field(i))
137 if err != nil {
138 return err
139 }
140 }
141 return nil
142 }
143}
144
145func decodeArray(t reflect.Type) decodeFunc {
146 f := getDecoder(t.Elem())
147 len := t.Len()
148
149 return func(r *Reader, v reflect.Value) error {
150 for i := 0; i < len; i++ {
151 err := f(r, v.Index(i))
152 if err != nil {
153 return err
154 }
155 }
156 return nil
157 }
158}
159
160func decodeSlice(t reflect.Type) decodeFunc {
161 elem := t.Elem()
162 f := getDecoder(elem)
163
164 return func(r *Reader, v reflect.Value) error {
165 len, err := r.ReadUint()
166 if err != nil {
167 return err
168 }
169
170 if len > maxArrayLength {
171 return fmt.Errorf("Array length %d exceeds configured limit of %d", len, maxArrayLength)
172 }
173
174 v.Set(reflect.MakeSlice(t, int(len), int(len)))
175
176 for i := 0; i < int(len); i++ {
177 if err := f(r, v.Index(i)); err != nil {
178 return err
179 }
180 }
181 return nil
182 }
183}
184
185func decodeMap(t reflect.Type) decodeFunc {
186 keyType := t.Key()
187 keyf := getDecoder(keyType)
188
189 valueType := t.Elem()
190 valf := getDecoder(valueType)
191
192 return func(r *Reader, v reflect.Value) error {
193 size, err := r.ReadUint()
194 if err != nil {
195 return err
196 }
197
198 if size > maxMapSize {
199 return fmt.Errorf("Map size %d exceeds configured limit of %d", size, maxMapSize)
200 }
201
202 v.Set(reflect.MakeMapWithSize(t, int(size)))
203
204 key := reflect.New(keyType).Elem()
205 value := reflect.New(valueType).Elem()
206
207 for i := uint64(0); i < size; i++ {
208 if err := keyf(r, key); err != nil {
209 return err
210 }
211
212 if v.MapIndex(key).Kind() > reflect.Invalid {
213 return fmt.Errorf("Encountered duplicate map key: %v", key.Interface())
214 }
215
216 if err := valf(r, value); err != nil {
217 return err
218 }
219
220 v.SetMapIndex(key, value)
221 }
222 return nil
223 }
224}
225
226func decodeUnion(t reflect.Type) decodeFunc {
227 ut, ok := unionRegistry[t]
228 if !ok {
229 return func(r *Reader, v reflect.Value) error {
230 return fmt.Errorf("Union type %s is not registered", t.Name())
231 }
232 }
233
234 decoders := make(map[uint64]decodeFunc)
235 for tag, t := range ut.types {
236 t := t
237 f := getDecoder(t)
238
239 decoders[tag] = func(r *Reader, v reflect.Value) error {
240 nv := reflect.New(t)
241 if err := f(r, nv.Elem()); err != nil {
242 return err
243 }
244
245 v.Set(nv)
246 return nil
247 }
248 }
249
250 return func(r *Reader, v reflect.Value) error {
251 tag, err := r.ReadUint()
252 if err != nil {
253 return err
254 }
255
256 if f, ok := decoders[tag]; ok {
257 return f(r, v)
258 }
259
260 return fmt.Errorf("Invalid union tag %d for type %s", tag, t.Name())
261 }
262}
263
264func decodeUint(r *Reader, v reflect.Value) error {
265 var err error
266 switch getIntKind(v.Type()) {
267 case reflect.Uint:
268 var u uint64
269 u, err = r.ReadUint()
270 v.SetUint(u)
271
272 case reflect.Uint8:
273 var u uint8
274 u, err = r.ReadU8()
275 v.SetUint(uint64(u))
276
277 case reflect.Uint16:
278 var u uint16
279 u, err = r.ReadU16()
280 v.SetUint(uint64(u))
281 case reflect.Uint32:
282 var u uint32
283 u, err = r.ReadU32()
284 v.SetUint(uint64(u))
285
286 case reflect.Uint64:
287 var u uint64
288 u, err = r.ReadU64()
289 v.SetUint(uint64(u))
290
291 default:
292 panic("not an uint")
293 }
294
295 return err
296}
297
298func decodeInt(r *Reader, v reflect.Value) error {
299 var err error
300 switch getIntKind(v.Type()) {
301 case reflect.Int:
302 var i int64
303 i, err = r.ReadInt()
304 v.SetInt(i)
305
306 case reflect.Int8:
307 var i int8
308 i, err = r.ReadI8()
309 v.SetInt(int64(i))
310
311 case reflect.Int16:
312 var i int16
313 i, err = r.ReadI16()
314 v.SetInt(int64(i))
315 case reflect.Int32:
316 var i int32
317 i, err = r.ReadI32()
318 v.SetInt(int64(i))
319
320 case reflect.Int64:
321 var i int64
322 i, err = r.ReadI64()
323 v.SetInt(int64(i))
324
325 default:
326 panic("not an int")
327 }
328
329 return err
330}
331
332func decodeFloat(r *Reader, v reflect.Value) error {
333 var err error
334 switch v.Type().Kind() {
335 case reflect.Float32:
336 var f float32
337 f, err = r.ReadF32()
338 v.SetFloat(float64(f))
339 case reflect.Float64:
340 var f float64
341 f, err = r.ReadF64()
342 v.SetFloat(f)
343 default:
344 panic("not a float")
345 }
346 return err
347}
348
349func decodeBool(r *Reader, v reflect.Value) error {
350 b, err := r.ReadBool()
351 v.SetBool(b)
352 return err
353}
354
355func decodeString(r *Reader, v reflect.Value) error {
356 s, err := r.ReadString()
357 v.SetString(s)
358 return err
359}
Note: See TracBrowser for help on using the repository browser.