[822] | 1 | package bare
|
---|
| 2 |
|
---|
| 3 | import (
|
---|
| 4 | "bytes"
|
---|
| 5 | "errors"
|
---|
| 6 | "fmt"
|
---|
| 7 | "reflect"
|
---|
| 8 | "sync"
|
---|
| 9 | )
|
---|
| 10 |
|
---|
// Marshalable is implemented by types that take responsibility for
// marshaling themselves when the encoder encounters them.
type Marshalable interface {
	// Marshal is handed the Writer the message is being encoded to; a
	// non-nil error aborts the encoding.
	Marshal(w *Writer) error
}
|
---|
| 16 |
|
---|
// encoderBufferPool hands out reusable scratch buffers for Marshal. Freshly
// created buffers are grown up front so they can hold at least 32 bytes.
var encoderBufferPool = sync.Pool{
	New: func() interface{} {
		var fresh bytes.Buffer
		fresh.Grow(32)
		return &fresh
	},
}
|
---|
| 24 |
|
---|
| 25 | // Marshals a value (val, which must be a pointer) into a BARE message.
|
---|
| 26 | //
|
---|
| 27 | // The encoding of each struct field can be customized by the format string
|
---|
| 28 | // stored under the "bare" key in the struct field's tag.
|
---|
| 29 | //
|
---|
| 30 | // As a special case, if the field tag is "-", the field is always omitted.
|
---|
| 31 | func Marshal(val interface{}) ([]byte, error) {
|
---|
| 32 | // reuse buffers from previous serializations
|
---|
| 33 | b := encoderBufferPool.Get().(*bytes.Buffer)
|
---|
| 34 | defer func() {
|
---|
| 35 | b.Reset()
|
---|
| 36 | encoderBufferPool.Put(b)
|
---|
| 37 | }()
|
---|
| 38 |
|
---|
| 39 | w := NewWriter(b)
|
---|
| 40 | err := MarshalWriter(w, val)
|
---|
| 41 |
|
---|
| 42 | msg := make([]byte, b.Len())
|
---|
| 43 | copy(msg, b.Bytes())
|
---|
| 44 |
|
---|
| 45 | return msg, err
|
---|
| 46 | }
|
---|
| 47 |
|
---|
| 48 | // Marshals a value (val, which must be a pointer) into a BARE message and
|
---|
| 49 | // writes it to a Writer. See Marshal for details.
|
---|
| 50 | func MarshalWriter(w *Writer, val interface{}) error {
|
---|
| 51 | t := reflect.TypeOf(val)
|
---|
| 52 | v := reflect.ValueOf(val)
|
---|
| 53 | if t.Kind() != reflect.Ptr {
|
---|
| 54 | return errors.New("Expected val to be pointer type")
|
---|
| 55 | }
|
---|
| 56 |
|
---|
| 57 | return getEncoder(t.Elem())(w, v.Elem())
|
---|
| 58 | }
|
---|
| 59 |
|
---|
| 60 | type encodeFunc func(w *Writer, v reflect.Value) error
|
---|
| 61 |
|
---|
| 62 | var encodeFuncCache sync.Map // map[reflect.Type]encodeFunc
|
---|
| 63 |
|
---|
| 64 | // get decoder from cache
|
---|
| 65 | func getEncoder(t reflect.Type) encodeFunc {
|
---|
| 66 | if f, ok := encodeFuncCache.Load(t); ok {
|
---|
| 67 | return f.(encodeFunc)
|
---|
| 68 | }
|
---|
| 69 |
|
---|
| 70 | f := encoderFunc(t)
|
---|
| 71 | encodeFuncCache.Store(t, f)
|
---|
| 72 | return f
|
---|
| 73 | }
|
---|
| 74 |
|
---|
| 75 | var marshalableInterface = reflect.TypeOf((*Unmarshalable)(nil)).Elem()
|
---|
| 76 |
|
---|
| 77 | func encoderFunc(t reflect.Type) encodeFunc {
|
---|
| 78 | if reflect.PtrTo(t).Implements(marshalableInterface) {
|
---|
| 79 | return func(w *Writer, v reflect.Value) error {
|
---|
| 80 | uv := v.Addr().Interface().(Marshalable)
|
---|
| 81 | return uv.Marshal(w)
|
---|
| 82 | }
|
---|
| 83 | }
|
---|
| 84 |
|
---|
| 85 | if t.Kind() == reflect.Interface && t.Implements(unionInterface) {
|
---|
| 86 | return encodeUnion(t)
|
---|
| 87 | }
|
---|
| 88 |
|
---|
| 89 | switch t.Kind() {
|
---|
| 90 | case reflect.Ptr:
|
---|
| 91 | return encodeOptional(t.Elem())
|
---|
| 92 | case reflect.Struct:
|
---|
| 93 | return encodeStruct(t)
|
---|
| 94 | case reflect.Array:
|
---|
| 95 | return encodeArray(t)
|
---|
| 96 | case reflect.Slice:
|
---|
| 97 | return encodeSlice(t)
|
---|
| 98 | case reflect.Map:
|
---|
| 99 | return encodeMap(t)
|
---|
| 100 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
---|
| 101 | return encodeUint
|
---|
| 102 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
---|
| 103 | return encodeInt
|
---|
| 104 | case reflect.Float32, reflect.Float64:
|
---|
| 105 | return encodeFloat
|
---|
| 106 | case reflect.Bool:
|
---|
| 107 | return encodeBool
|
---|
| 108 | case reflect.String:
|
---|
| 109 | return encodeString
|
---|
| 110 | }
|
---|
| 111 |
|
---|
| 112 | return func(w *Writer, v reflect.Value) error {
|
---|
| 113 | return &UnsupportedTypeError{v.Type()}
|
---|
| 114 | }
|
---|
| 115 | }
|
---|
| 116 |
|
---|
| 117 | func encodeOptional(t reflect.Type) encodeFunc {
|
---|
| 118 | return func(w *Writer, v reflect.Value) error {
|
---|
| 119 | if v.IsNil() {
|
---|
| 120 | return w.WriteBool(false)
|
---|
| 121 | }
|
---|
| 122 |
|
---|
| 123 | if err := w.WriteBool(true); err != nil {
|
---|
| 124 | return err
|
---|
| 125 | }
|
---|
| 126 |
|
---|
| 127 | return getEncoder(t)(w, v.Elem())
|
---|
| 128 | }
|
---|
| 129 | }
|
---|
| 130 |
|
---|
| 131 | func encodeStruct(t reflect.Type) encodeFunc {
|
---|
| 132 | n := t.NumField()
|
---|
| 133 | encoders := make([]encodeFunc, n)
|
---|
| 134 | for i := 0; i < n; i++ {
|
---|
| 135 | field := t.Field(i)
|
---|
| 136 | if field.Tag.Get("bare") == "-" {
|
---|
| 137 | continue
|
---|
| 138 | }
|
---|
| 139 | encoders[i] = getEncoder(field.Type)
|
---|
| 140 | }
|
---|
| 141 |
|
---|
| 142 | return func(w *Writer, v reflect.Value) error {
|
---|
| 143 | for i := 0; i < n; i++ {
|
---|
| 144 | if encoders[i] == nil {
|
---|
| 145 | continue
|
---|
| 146 | }
|
---|
| 147 | err := encoders[i](w, v.Field(i))
|
---|
| 148 | if err != nil {
|
---|
| 149 | return err
|
---|
| 150 | }
|
---|
| 151 | }
|
---|
| 152 | return nil
|
---|
| 153 | }
|
---|
| 154 | }
|
---|
| 155 |
|
---|
| 156 | func encodeArray(t reflect.Type) encodeFunc {
|
---|
| 157 | f := getEncoder(t.Elem())
|
---|
| 158 | len := t.Len()
|
---|
| 159 |
|
---|
| 160 | return func(w *Writer, v reflect.Value) error {
|
---|
| 161 | for i := 0; i < len; i++ {
|
---|
| 162 | if err := f(w, v.Index(i)); err != nil {
|
---|
| 163 | return err
|
---|
| 164 | }
|
---|
| 165 | }
|
---|
| 166 | return nil
|
---|
| 167 | }
|
---|
| 168 | }
|
---|
| 169 |
|
---|
| 170 | func encodeSlice(t reflect.Type) encodeFunc {
|
---|
| 171 | elem := t.Elem()
|
---|
| 172 | f := getEncoder(elem)
|
---|
| 173 |
|
---|
| 174 | return func(w *Writer, v reflect.Value) error {
|
---|
| 175 | if err := w.WriteUint(uint64(v.Len())); err != nil {
|
---|
| 176 | return err
|
---|
| 177 | }
|
---|
| 178 |
|
---|
| 179 | for i := 0; i < v.Len(); i++ {
|
---|
| 180 | if err := f(w, v.Index(i)); err != nil {
|
---|
| 181 | return err
|
---|
| 182 | }
|
---|
| 183 | }
|
---|
| 184 | return nil
|
---|
| 185 | }
|
---|
| 186 | }
|
---|
| 187 |
|
---|
| 188 | func encodeMap(t reflect.Type) encodeFunc {
|
---|
| 189 | keyType := t.Key()
|
---|
| 190 | keyf := getEncoder(keyType)
|
---|
| 191 |
|
---|
| 192 | valueType := t.Elem()
|
---|
| 193 | valf := getEncoder(valueType)
|
---|
| 194 |
|
---|
| 195 | return func(w *Writer, v reflect.Value) error {
|
---|
| 196 | if err := w.WriteUint(uint64(v.Len())); err != nil {
|
---|
| 197 | return err
|
---|
| 198 | }
|
---|
| 199 |
|
---|
| 200 | iter := v.MapRange()
|
---|
| 201 | for iter.Next() {
|
---|
| 202 | if err := keyf(w, iter.Key()); err != nil {
|
---|
| 203 | return err
|
---|
| 204 | }
|
---|
| 205 | if err := valf(w, iter.Value()); err != nil {
|
---|
| 206 | return err
|
---|
| 207 | }
|
---|
| 208 | }
|
---|
| 209 | return nil
|
---|
| 210 | }
|
---|
| 211 | }
|
---|
| 212 |
|
---|
| 213 | func encodeUnion(t reflect.Type) encodeFunc {
|
---|
| 214 | ut, ok := unionRegistry[t]
|
---|
| 215 | if !ok {
|
---|
| 216 | return func(w *Writer, v reflect.Value) error {
|
---|
| 217 | return fmt.Errorf("Union type %s is not registered", t.Name())
|
---|
| 218 | }
|
---|
| 219 | }
|
---|
| 220 |
|
---|
| 221 | encoders := make(map[uint64]encodeFunc)
|
---|
| 222 | for tag, t := range ut.types {
|
---|
| 223 | encoders[tag] = getEncoder(t)
|
---|
| 224 | }
|
---|
| 225 |
|
---|
| 226 | return func(w *Writer, v reflect.Value) error {
|
---|
| 227 | t := v.Elem().Type()
|
---|
| 228 | if t.Kind() == reflect.Ptr {
|
---|
| 229 | // If T is a valid union value type, *T is valid too.
|
---|
| 230 | t = t.Elem()
|
---|
| 231 | v = v.Elem()
|
---|
| 232 | }
|
---|
| 233 | tag, ok := ut.tags[t]
|
---|
| 234 | if !ok {
|
---|
| 235 | return fmt.Errorf("Invalid union value: %s", v.Elem().String())
|
---|
| 236 | }
|
---|
| 237 |
|
---|
| 238 | if err := w.WriteUint(tag); err != nil {
|
---|
| 239 | return err
|
---|
| 240 | }
|
---|
| 241 |
|
---|
| 242 | return encoders[tag](w, v.Elem())
|
---|
| 243 | }
|
---|
| 244 | }
|
---|
| 245 |
|
---|
| 246 | func encodeUint(w *Writer, v reflect.Value) error {
|
---|
| 247 | switch getIntKind(v.Type()) {
|
---|
| 248 | case reflect.Uint:
|
---|
| 249 | return w.WriteUint(v.Uint())
|
---|
| 250 |
|
---|
| 251 | case reflect.Uint8:
|
---|
| 252 | return w.WriteU8(uint8(v.Uint()))
|
---|
| 253 |
|
---|
| 254 | case reflect.Uint16:
|
---|
| 255 | return w.WriteU16(uint16(v.Uint()))
|
---|
| 256 |
|
---|
| 257 | case reflect.Uint32:
|
---|
| 258 | return w.WriteU32(uint32(v.Uint()))
|
---|
| 259 |
|
---|
| 260 | case reflect.Uint64:
|
---|
| 261 | return w.WriteU64(uint64(v.Uint()))
|
---|
| 262 | }
|
---|
| 263 |
|
---|
| 264 | panic("not uint")
|
---|
| 265 | }
|
---|
| 266 |
|
---|
| 267 | func encodeInt(w *Writer, v reflect.Value) error {
|
---|
| 268 | switch getIntKind(v.Type()) {
|
---|
| 269 | case reflect.Int:
|
---|
| 270 | return w.WriteInt(v.Int())
|
---|
| 271 |
|
---|
| 272 | case reflect.Int8:
|
---|
| 273 | return w.WriteI8(int8(v.Int()))
|
---|
| 274 |
|
---|
| 275 | case reflect.Int16:
|
---|
| 276 | return w.WriteI16(int16(v.Int()))
|
---|
| 277 |
|
---|
| 278 | case reflect.Int32:
|
---|
| 279 | return w.WriteI32(int32(v.Int()))
|
---|
| 280 |
|
---|
| 281 | case reflect.Int64:
|
---|
| 282 | return w.WriteI64(int64(v.Int()))
|
---|
| 283 | }
|
---|
| 284 |
|
---|
| 285 | panic("not int")
|
---|
| 286 | }
|
---|
| 287 |
|
---|
| 288 | func encodeFloat(w *Writer, v reflect.Value) error {
|
---|
| 289 | switch v.Type().Kind() {
|
---|
| 290 | case reflect.Float32:
|
---|
| 291 | return w.WriteF32(float32(v.Float()))
|
---|
| 292 | case reflect.Float64:
|
---|
| 293 | return w.WriteF64(v.Float())
|
---|
| 294 | }
|
---|
| 295 |
|
---|
| 296 | panic("not float")
|
---|
| 297 | }
|
---|
| 298 |
|
---|
| 299 | func encodeBool(w *Writer, v reflect.Value) error {
|
---|
| 300 | return w.WriteBool(v.Bool())
|
---|
| 301 | }
|
---|
| 302 |
|
---|
| 303 | func encodeString(w *Writer, v reflect.Value) error {
|
---|
| 304 | if v.Kind() != reflect.String {
|
---|
| 305 | panic("not string")
|
---|
| 306 | }
|
---|
| 307 | return w.WriteString(v.String())
|
---|
| 308 | }
|
---|